diff --git a/config/config_jepa.yml b/config/config_jepa.yml index fc27da8c9..9b5495897 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -130,10 +130,39 @@ data_loading : # config for training training_config: - + # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["student_teacher"] + # Collapse monitoring for SSL training (JEPA/DINO/iBOT) + # Detects representation collapse via various metrics + # For forecasting, supports sequences of latents with per-step and aggregate metrics + collapse_monitoring: + enabled: true + compute_frequency: 100 # batches between metric computations + log_frequency: 100 # batches between metric logging + metrics: + effective_rank: + enabled: true + tensor_source: "both" # "student", "teacher", or "both" + sample_size: 2048 # max samples for SVD (0 = no sampling) + # For forecasting sequences: "all" (per-step + aggregates), + # "aggregate_only" (mean/min/max/degradation), "per_step_only" + forecast_aggregation: "all" + singular_values: + enabled: true + tensor_source: "both" + sample_size: 2048 + forecast_aggregation: "all" + dimension_variance: + enabled: true + tensor_source: "both" # cheap to compute, good early indicator + forecast_aggregation: "all" + prototype_entropy: + enabled: true # only applies to DINO + ema_beta: + enabled: true + num_mini_epochs: 32 samples_per_mini_epoch: 4096 shuffle: True @@ -182,8 +211,8 @@ training_config: }, }, target_and_aux_calc: { "EMATeacher" : - { ema_ramp_up_ratio : 0.09, - ema_halflife_in_thousands: 1e-3, + { ema_ramp_up_ratio : null, + ema_halflife_in_thousands: 1e-1, model_param_overrides : { training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} }, diff --git a/config/default_config.yml b/config/default_config.yml index 4bff1abfd..0d1dc5230 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -134,10 +134,39 @@ data_loading : # config for training training_config: - + # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["masking"] + # Collapse monitoring for detecting representation collapse + # Works with SSL training (JEPA/DINO) and forecasting modes + # For forecasting, monitors latent_state.patch_tokens at each forecast step + collapse_monitoring: + enabled: true + compute_frequency: 100 # batches between metric computations + log_frequency: 100 # batches between metric logging + metrics: + effective_rank: + enabled: true + tensor_source: "student" # "student", "teacher", or "both" + sample_size: 2048 # max samples for SVD (0 = no sampling) + # For forecasting sequences: "all" (per-step + aggregates), + # "aggregate_only" (mean/min/max/degradation), "per_step_only" + forecast_aggregation: "all" + singular_values: + enabled: true + tensor_source: "student" + sample_size: 2048 + forecast_aggregation: "all" + dimension_variance: + enabled: true + tensor_source: "student" + forecast_aggregation: "all" + prototype_entropy: + enabled: false # only relevant for DINO + ema_beta: + enabled: false # only relevant for SSL with EMA teacher + num_mini_epochs: 32 samples_per_mini_epoch: 4096 shuffle: True diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py index 141947863..b42d756d4 100644 --- a/src/weathergen/model/ema.py +++ b/src/weathergen/model/ema.py @@ -30,6 +30,7 @@ def __init__( self.rampup_ratio = rampup_ratio self.ema_model = empty_model self.is_model_sharded = is_model_sharded + self.batch_size = 1 # Build a name → param map once self.src_params = 
dict(self.original_model.named_parameters()) @@ -55,16 +56,33 @@ def requires_grad_(self, flag: bool): for p in self.ema_model.parameters(): p.requires_grad = flag + def get_current_beta(self, cur_step: int) -> float: + """ + Get current EMA beta value for monitoring. + + The beta value determines how much the teacher model is updated towards + the student model at each step. Higher beta means slower teacher updates. + + Args: + cur_step: Current training step (typically istep * batch_size). + + Returns: + Current EMA beta value. + """ + halflife_steps = self.halflife_steps + if self.rampup_ratio is not None: + halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio) + beta = 0.5 ** (self.batch_size / max(halflife_steps, 1e-6)) + return beta + @torch.no_grad() def update(self, cur_step, batch_size): # ensure model remains sharded if self.is_model_sharded: self.ema_model.reshard() # determine correct interpolation params - halflife_steps = self.halflife_steps - if self.rampup_ratio is not None: - halflife_steps = min(halflife_steps, cur_step / 1e3 * self.rampup_ratio) - beta = 0.5 ** (batch_size / max(halflife_steps * 1e3, 1e-6)) + self.batch_size = batch_size + beta = self.get_current_beta(cur_step) for name, p_ema in self.ema_model.named_parameters(): p_src = self.src_params.get(name, None) diff --git a/src/weathergen/train/collapse_monitor.py b/src/weathergen/train/collapse_monitor.py new file mode 100644 index 000000000..2ab2f1eb7 --- /dev/null +++ b/src/weathergen/train/collapse_monitor.py @@ -0,0 +1,604 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +Collapse monitoring metrics for SSL training (JEPA/DINO). + +This module implements metrics to detect representation collapse during self-supervised learning: +- Effective Rank (RankMe): Entropy of normalized singular values +- Singular Value Spectrum: Top-k singular values and concentration ratio +- Per-Dimension Variance: Min/mean/max variance across embedding dimensions +- Prototype Entropy: Normalized entropy of DINO prototype assignments +- EMA Beta: Current teacher momentum value + +For forecasting, the monitor supports sequences of latents (one per time step) and computes: +- Per-step metrics (e.g., effective_rank.step_0, step_1, ...) +- Aggregate metrics (mean, min across steps) +- Degradation ratio (final step / initial step) + +References: +- RankMe (ICML 2023): https://arxiv.org/abs/2210.02885 +- C-JEPA (NeurIPS 2024): https://arxiv.org/abs/2410.19560 +""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from collections.abc import Callable +from typing import Any + +import torch + +logger = logging.getLogger(__name__) + +# Valid values for tensor_source config option +VALID_TENSOR_SOURCES = frozenset({"student", "teacher", "both"}) + +# Valid values for forecast_aggregation config option +VALID_FORECAST_AGGREGATIONS = frozenset({"all", "aggregate_only", "per_step_only"}) + + +class CollapseMonitor: + """ + Monitor for detecting representation collapse during SSL training. + + Computes and caches various collapse indicators that can be logged + at configurable intervals to minimize computational overhead. 
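+    Metrics accumulated between log intervals are averaged when read back via get_cached_metrics().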
+ + Supports both single latent tensors and sequences of latents for forecasting. + """ + + def __init__(self, config: dict[str, Any], device: torch.device) -> None: + """ + Initialize the collapse monitor. + + Args: + config: Configuration dictionary with collapse_monitoring settings. + device: Device to use for computations (currently unused, tensors + are processed on their original device). + + Raises: + ValueError: If config contains invalid values. + """ + self.enabled = config.get("enabled", False) + self.compute_frequency = config.get("compute_frequency", 100) + self.log_frequency = config.get("log_frequency", 100) + + # Validate frequencies + if self.compute_frequency <= 0: + raise ValueError(f"compute_frequency must be positive, got {self.compute_frequency}") + if self.log_frequency <= 0: + raise ValueError(f"log_frequency must be positive, got {self.log_frequency}") + + # Metric configurations + metrics_config = config.get("metrics", {}) + + self.effective_rank_config = metrics_config.get("effective_rank", {}) + self.singular_values_config = metrics_config.get("singular_values", {}) + self.dimension_variance_config = metrics_config.get("dimension_variance", {}) + self.prototype_entropy_config = metrics_config.get("prototype_entropy", {}) + self.ema_beta_config = metrics_config.get("ema_beta", {}) + + # Validate tensor_source values + self._validate_tensor_source(self.effective_rank_config, "effective_rank") + self._validate_tensor_source(self.singular_values_config, "singular_values") + self._validate_tensor_source(self.dimension_variance_config, "dimension_variance") + + # Validate forecast_aggregation values + self._validate_forecast_aggregation(self.effective_rank_config, "effective_rank") + self._validate_forecast_aggregation(self.singular_values_config, "singular_values") + self._validate_forecast_aggregation(self.dimension_variance_config, "dimension_variance") + + # Cache for accumulating metrics between log intervals + self._metrics_cache: defaultdict[str, list[float]] = defaultdict(list) + + def _validate_tensor_source(self, metric_config: dict[str, Any], metric_name: str) -> None: + """Validate tensor_source config value.""" + source = metric_config.get("tensor_source", "both") + if source not in VALID_TENSOR_SOURCES: + raise ValueError( + f"Invalid tensor_source '{source}' for {metric_name}. " + f"Must be one of: {sorted(VALID_TENSOR_SOURCES)}" + ) + + def _validate_forecast_aggregation( + self, metric_config: dict[str, Any], metric_name: str + ) -> None: + """Validate forecast_aggregation config value.""" + aggregation = metric_config.get("forecast_aggregation", "all") + if aggregation not in VALID_FORECAST_AGGREGATIONS: + raise ValueError( + f"Invalid forecast_aggregation '{aggregation}' for {metric_name}. " + f"Must be one of: {sorted(VALID_FORECAST_AGGREGATIONS)}" + ) + + def should_compute(self, step: int) -> bool: + """Check if metrics should be computed at this step.""" + return self.enabled and step % self.compute_frequency == 0 + + def should_log(self, step: int) -> bool: + """Check if metrics should be logged at this step.""" + return self.enabled and step % self.log_frequency == 0 + + def _get_tensors_to_monitor( + self, + student_latent: torch.Tensor | list[torch.Tensor] | None, + teacher_latent: torch.Tensor | list[torch.Tensor] | None, + metric_config: dict[str, Any], + ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: + """ + Get tensors to monitor based on metric config's tensor_source. + + Args: + student_latent: Student latent(s). 
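+                Single tensor, or a list of per-step tensors in forecasting mode.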
+ teacher_latent: Teacher latent(s). + metric_config: Config dict with tensor_source key. + + Returns: + Dict mapping "student"/"teacher" to their tensors (if requested). + """ + source = metric_config.get("tensor_source", "both") + result: dict[str, torch.Tensor | list[torch.Tensor] | None] = {} + + if source in ("student", "both"): + result["student"] = student_latent + if source in ("teacher", "both"): + result["teacher"] = teacher_latent + + return result + + def _compute_sequence_metrics( + self, + latents: list[torch.Tensor], + compute_fn: Callable[..., float], + metric_name: str, + aggregation: str, + **kwargs: Any, + ) -> dict[str, float]: + """ + Compute metrics for a sequence of latents (forecasting). + + Args: + latents: List of latent tensors, one per time step. + compute_fn: Function to compute metric for a single tensor. + metric_name: Base name for the metric (e.g., "effective_rank"). + aggregation: One of "all", "aggregate_only", "per_step_only". + **kwargs: Additional arguments passed to compute_fn. + + Returns: + Dictionary of metrics with per-step and/or aggregate values. + """ + metrics: dict[str, float] = {} + + if not latents: + return metrics + + # Compute per-step metrics + step_values: list[float] = [] + for step_idx, latent in enumerate(latents): + value = compute_fn(latent, **kwargs) + step_values.append(value) + + if aggregation in ("all", "per_step_only"): + metrics[f"{metric_name}.step_{step_idx}"] = value + + # Compute aggregate metrics + if aggregation in ("all", "aggregate_only") and step_values: + # Filter out invalid values (0.0 indicates computation failure) + valid_values = [v for v in step_values if v > 0] + + if valid_values: + metrics[f"{metric_name}.mean"] = sum(valid_values) / len(valid_values) + metrics[f"{metric_name}.min"] = min(valid_values) + metrics[f"{metric_name}.max"] = max(valid_values) + + # Degradation: ratio of last step to first step + # Values > 1 mean rank increased, < 1 means degradation + if step_values[0] > 0 and step_values[-1] > 0: + metrics[f"{metric_name}.degradation"] = step_values[-1] / step_values[0] + + return metrics + + def _compute_sequence_dict_metrics( + self, + latents: list[torch.Tensor], + compute_fn: Callable[..., dict[str, float]], + base_prefix: str, + aggregation: str, + **kwargs: Any, + ) -> dict[str, float]: + """ + Compute dict-returning metrics for a sequence of latents. + + For metrics like singular_values that return multiple values per tensor. + + Args: + latents: List of latent tensors. + compute_fn: Function returning dict of metrics for single tensor. + base_prefix: Prefix for metric names (e.g., "collapse.student"). + aggregation: One of "all", "aggregate_only", "per_step_only". + **kwargs: Additional arguments passed to compute_fn. + + Returns: + Dictionary of metrics. 
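+            Keys follow "{base_prefix}.{key}.step_{i}" for per-step values and "{base_prefix}.{key}.mean/min/max" for aggregates.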
+ """ + metrics: dict[str, float] = {} + + if not latents: + return metrics + + # Collect per-step values for each sub-metric + step_metrics: dict[str, list[float]] = defaultdict(list) + + for step_idx, latent in enumerate(latents): + step_result = compute_fn(latent, **kwargs) + + for key, value in step_result.items(): + step_metrics[key].append(value) + + if aggregation in ("all", "per_step_only"): + metrics[f"{base_prefix}.{key}.step_{step_idx}"] = value + + # Compute aggregates for each sub-metric + if aggregation in ("all", "aggregate_only"): + for key, values in step_metrics.items(): + valid_values = [v for v in values if v > 0 or key.startswith("var_")] + if valid_values: + metrics[f"{base_prefix}.{key}.mean"] = sum(valid_values) / len(valid_values) + metrics[f"{base_prefix}.{key}.min"] = min(valid_values) + metrics[f"{base_prefix}.{key}.max"] = max(valid_values) + + return metrics + + def compute_metrics( + self, + student_latent: torch.Tensor | list[torch.Tensor] | None = None, + teacher_latent: torch.Tensor | list[torch.Tensor] | None = None, + prototype_probs: torch.Tensor | None = None, + ema_beta: float | None = None, + loss_type: str | None = None, + ) -> dict[str, float]: + """ + Compute all enabled collapse monitoring metrics. + + Supports both single tensors and sequences of tensors (for forecasting). + For sequences, computes per-step metrics and aggregates based on config. + + Args: + student_latent: Student model latent representations. + Single tensor [B, N, D] or [B, D], or list of tensors for forecasting. + teacher_latent: Teacher model latent representations. + Single tensor [B, N, D] or [B, D], or list of tensors for forecasting. + prototype_probs: Post-softmax prototype assignment probabilities [B, K] (DINO only). + ema_beta: Current EMA momentum value. + loss_type: Type of SSL loss ("JEPA" or "DINO"). + + Returns: + Dictionary of computed metrics. 
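+            Keys are prefixed with "collapse.student" / "collapse.teacher"; the scalar metrics use "collapse.dino.prototype_entropy" and "collapse.ema_beta".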
+ """ + if not self.enabled: + return {} + + metrics: dict[str, float] = {} + + # Compute effective rank + if self.effective_rank_config.get("enabled", True): + sample_size = self.effective_rank_config.get("sample_size", 2048) + aggregation = self.effective_rank_config.get("forecast_aggregation", "all") + tensors = self._get_tensors_to_monitor( + student_latent, teacher_latent, self.effective_rank_config + ) + + for name, tensor in tensors.items(): + if tensor is None: + continue + + if isinstance(tensor, list): + seq_metrics = self._compute_sequence_metrics( + tensor, + self._compute_effective_rank, + f"collapse.{name}.effective_rank", + aggregation, + sample_size=sample_size, + ) + metrics.update(seq_metrics) + else: + eff_rank = self._compute_effective_rank(tensor, sample_size) + metrics[f"collapse.{name}.effective_rank"] = eff_rank + + # Compute singular value spectrum + if self.singular_values_config.get("enabled", True): + sample_size = self.singular_values_config.get("sample_size", 2048) + aggregation = self.singular_values_config.get("forecast_aggregation", "all") + tensors = self._get_tensors_to_monitor( + student_latent, teacher_latent, self.singular_values_config + ) + + for name, tensor in tensors.items(): + if tensor is None: + continue + + if isinstance(tensor, list): + seq_metrics = self._compute_sequence_dict_metrics( + tensor, + self._compute_singular_values, + f"collapse.{name}", + aggregation, + sample_size=sample_size, + ) + metrics.update(seq_metrics) + else: + sv_metrics = self._compute_singular_values(tensor, sample_size) + for key, value in sv_metrics.items(): + metrics[f"collapse.{name}.{key}"] = value + + # Compute per-dimension variance + if self.dimension_variance_config.get("enabled", True): + aggregation = self.dimension_variance_config.get("forecast_aggregation", "all") + tensors = self._get_tensors_to_monitor( + student_latent, teacher_latent, self.dimension_variance_config + ) + + for name, tensor in tensors.items(): + if tensor is None: + continue + + if isinstance(tensor, list): + seq_metrics = self._compute_sequence_dict_metrics( + tensor, + self._compute_dimension_variance, + f"collapse.{name}", + aggregation, + ) + metrics.update(seq_metrics) + else: + var_metrics = self._compute_dimension_variance(tensor) + for key, value in var_metrics.items(): + metrics[f"collapse.{name}.{key}"] = value + + # Compute prototype entropy (DINO only) + if ( + self.prototype_entropy_config.get("enabled", True) + and prototype_probs is not None + and loss_type == "DINO" + ): + entropy = self._compute_prototype_entropy(prototype_probs) + metrics["collapse.dino.prototype_entropy"] = entropy + + # Log EMA beta + if self.ema_beta_config.get("enabled", True) and ema_beta is not None: + metrics["collapse.ema_beta"] = ema_beta + + # Cache metrics for averaging + for key, value in metrics.items(): + self._metrics_cache[key].append(value) + + return metrics + + def get_cached_metrics(self) -> dict[str, float]: + """ + Get averaged cached metrics and clear the cache. + + Returns: + Dictionary of averaged metrics since last call. + """ + averaged_metrics: dict[str, float] = {} + for key, values in self._metrics_cache.items(): + if values: + averaged_metrics[key] = sum(values) / len(values) + + self._metrics_cache.clear() + return averaged_metrics + + def _flatten_to_samples(self, z: torch.Tensor) -> torch.Tensor: + """ + Flatten patch dimension into sample dimension. + + Treats [B, N, D] as [B*N, D] where each patch is an independent sample. 
+ This is consistent with C-JEPA/VICReg approach. + + Args: + z: Tensor of shape [B, N, D] or [B, D]. + + Returns: + Tensor of shape [B*N, D] or [B, D]. + """ + # Convert to float32 for SVD compatibility (bfloat16/float16 can fail) + if z.dtype in (torch.bfloat16, torch.float16): + z = z.float() + + if z.ndim == 3: + return z.reshape(-1, z.shape[-1]) + return z + + def _sample_rows(self, z: torch.Tensor, sample_size: int) -> torch.Tensor: + """ + Randomly sample rows to reduce SVD computation cost. + + Args: + z: Tensor of shape [N, D]. + sample_size: Maximum number of samples (0 = no sampling). + + Returns: + Sampled tensor of shape [min(N, sample_size), D]. + """ + if sample_size <= 0 or z.shape[0] <= sample_size: + return z + + indices = torch.randperm(z.shape[0], device=z.device)[:sample_size] + return z[indices] + + def _compute_effective_rank(self, z: torch.Tensor, sample_size: int = 2048) -> float: + """ + Compute effective rank via entropy of normalized singular values (RankMe). + + The effective rank measures how many dimensions are actually being used + in the representation. A low effective rank indicates collapse. + + Args: + z: Latent representations [B, N, D] or [B, D]. + sample_size: Maximum samples for SVD computation. + + Returns: + Effective rank (exp of entropy of normalized singular values). + """ + z = self._flatten_to_samples(z.detach()) + z = self._sample_rows(z, sample_size) + + # Validate tensor before SVD + if z.numel() == 0: + logger.warning("Empty tensor in effective rank computation") + return 0.0 + if torch.isnan(z).any() or torch.isinf(z).any(): + logger.warning("NaN/Inf values in tensor for effective rank computation") + return 0.0 + if z.shape[0] < 2 or z.shape[1] < 2: + logger.warning(f"Tensor too small for SVD: shape={z.shape}") + return 0.0 + + # Center the data + z_centered = z - z.mean(dim=0, keepdim=True) + + # Compute SVD + try: + _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) + except RuntimeError as e: + # SVD can fail on degenerate matrices + logger.warning(f"SVD failed in effective rank computation: {e}, shape={z.shape}") + return 0.0 + + # Normalize singular values to get a probability distribution + s_normalized = s / (s.sum() + 1e-8) + + # Compute entropy + entropy = -torch.sum(s_normalized * torch.log(s_normalized + 1e-8)) + + # Effective rank is exp(entropy) + effective_rank = torch.exp(entropy) + + return effective_rank.item() + + def _compute_singular_values( + self, z: torch.Tensor, sample_size: int = 2048 + ) -> dict[str, float]: + """ + Compute singular value statistics and concentration ratio. + + The concentration ratio (top SV / sum of all SVs) indicates how much + variance is captured by the largest singular value. High concentration + suggests dimensional collapse. + + Args: + z: Latent representations [B, N, D] or [B, D]. + sample_size: Maximum samples for SVD computation. + + Returns: + Dictionary with sv_min, sv_max, sv_mean, and sv_concentration. 
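+            Empty dict if the tensor is invalid or the SVD fails.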
+ """ + z = self._flatten_to_samples(z.detach()) + z = self._sample_rows(z, sample_size) + + # Validate tensor before SVD + if z.numel() == 0: + logger.warning("Empty tensor in singular value computation") + return {} + if torch.isnan(z).any() or torch.isinf(z).any(): + logger.warning("NaN/Inf values in tensor for singular value computation") + return {} + if z.shape[0] < 2 or z.shape[1] < 2: + logger.warning(f"Tensor too small for SVD: shape={z.shape}") + return {} + + # Center the data + z_centered = z - z.mean(dim=0, keepdim=True) + + # Compute SVD + try: + _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) + except RuntimeError as e: + logger.warning(f"SVD failed in singular value computation: {e}, shape={z.shape}") + return {} + + metrics: dict[str, float] = {} + + # Singular value statistics + metrics["sv_min"] = s.min().item() + metrics["sv_max"] = s.max().item() + metrics["sv_mean"] = s.mean().item() + + # Concentration ratio (top SV / sum) + s_sum = s.sum() + 1e-8 + metrics["sv_concentration"] = (s[0] / s_sum).item() + + return metrics + + def _compute_dimension_variance(self, z: torch.Tensor) -> dict[str, float]: + """ + Compute per-dimension variance statistics. + + Low minimum variance indicates "dead" dimensions that are not being used. + Large variance ratio (max/min) suggests imbalanced dimension usage. + + Args: + z: Latent representations [B, N, D] or [B, D]. + + Returns: + Dictionary with var_min, var_mean, var_max. Empty dict if tensor is invalid. + """ + z = self._flatten_to_samples(z.detach()) + + # Validate tensor + if z.numel() == 0: + logger.warning("Empty tensor in dimension variance computation") + return {} + if torch.isnan(z).any() or torch.isinf(z).any(): + logger.warning("NaN/Inf values in tensor for dimension variance computation") + return {} + if z.shape[0] < 2: + logger.warning(f"Need at least 2 samples to compute variance: shape={z.shape}") + return {} + + # Compute variance along sample dimension + var_per_dim = z.var(dim=0) + + return { + "var_min": var_per_dim.min().item(), + "var_mean": var_per_dim.mean().item(), + "var_max": var_per_dim.max().item(), + } + + def _compute_prototype_entropy(self, probs: torch.Tensor) -> float: + """ + Compute normalized entropy of DINO prototype assignments. + + Low entropy indicates collapse to few prototypes. Entropy is normalized + to [0, 1] range where 1 means uniform distribution. + + Args: + probs: Post-softmax prototype assignment probabilities [B, K]. + + Returns: + Normalized entropy in [0, 1]. 
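+            Values near 0 indicate collapse onto a few prototypes; values near 1 indicate near-uniform prototype usage.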
+ """ + probs = probs.detach() + + # Average across batch to get prototype usage distribution + avg_probs = probs.mean(dim=0) + + # Compute entropy + entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-8)) + + # Normalize by maximum possible entropy (uniform distribution) + num_prototypes = probs.shape[1] + max_entropy = torch.log(torch.tensor(float(num_prototypes), device=probs.device)) + + normalized_entropy = entropy / (max_entropy + 1e-8) + + return normalized_entropy.item() diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index cfb252f86..76994221c 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -77,6 +77,10 @@ def to_device(self, device) -> EMATeacher: module.to(device) return self + def get_current_beta(self, cur_step: int) -> float: + beta = self.ema_model.get_current_beta(cur_step) + return beta + def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): return_dict = {} diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index e949dc1cc..fda4eab60 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -24,13 +24,16 @@ from weathergen.common.config import Config, merge_configs from weathergen.datasets.multi_stream_data_sampler import MultiStreamDataSampler from weathergen.model.ema import EMAModel +from weathergen.model.engines import LatentState from weathergen.model.model_interface import ( get_target_aux_calculator, init_model_and_shard, ) from weathergen.model.utils import apply_fct_to_blocks, set_to_eval +from weathergen.train.collapse_monitor import CollapseMonitor from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler +from weathergen.train.target_and_aux_ssl_teacher import EMATeacher from weathergen.train.trainer_base import TrainerBase from weathergen.train.utils import ( extract_batch_metadata, @@ -74,6 +77,7 @@ def __init__(self, train_log_freq: Config): self.batch_size_per_gpu = -1 self.batch_size_validation_per_gpu = -1 self.batch_size_test_per_gpu = -1 + self.collapse_monitor: CollapseMonitor | None = None def get_batch_size_total(self, batch_size_per_gpu) -> int: """ @@ -146,6 +150,10 @@ def init(self, cf: Config, devices): self.train_logger = TrainLogger(cf, config.get_path_run(self.cf)) + # Initialize collapse monitor for SSL training + collapse_config = self.training_cfg.get("collapse_monitoring", {}) + self.collapse_monitor = CollapseMonitor(collapse_config, None) # device set later in run() + def get_target_aux_calculators(self, mode_cfg): """ Get target_aux_calculators for given mode_cfg @@ -227,6 +235,9 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): device_type = torch.accelerator.current_accelerator() self.device = torch.device(f"{device_type}:{cf.local_rank}") + # Update collapse monitor device + self.collapse_monitor.device = self.device + # create data loaders self.dataset = MultiStreamDataSampler(cf, self.training_cfg, stage=TRAIN) self.dataset_val = MultiStreamDataSampler(cf, self.validation_cfg, stage=VAL) @@ -501,9 +512,16 @@ def train(self, mini_epoch): if self.validate_with_ema: self.ema_model.update(self.cf.general.istep * batch_size_total, batch_size_total) + # Compute collapse monitoring metrics + if self.collapse_monitor.should_compute(self.cf.general.istep): + self._compute_collapse_metrics(preds, targets_and_auxs) + 
            self._log_terminal(bidx, mini_epoch, TRAIN)
            if bidx % self.train_log_freq.metrics == 0:
                self._log(TRAIN)
+                # Log collapse metrics
+                if self.collapse_monitor.should_log(self.cf.general.istep):
+                    self._log_collapse_metrics(TRAIN)

            # save model checkpoint (with designation _latest)
            if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0:
@@ -775,3 +793,155 @@ def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage):
            logger.info("\n")

        self.t_start = time.time()
+
+    def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None:
+        """
+        Extract latent tensors from predictions and targets, then compute collapse metrics.
+
+        This method extracts the student and teacher latent representations from the
+        model outputs. It supports two modes:
+
+        1. SSL training (JEPA/DINO/iBOT): Extracts latents from SSL-specific keys
+        2. Forecasting: Extracts latents from 'latent_state' at each forecast step
+
+        For forecasting, a list of latent tensors is passed to enable per-step metrics.
+        """
+        student_latent: torch.Tensor | list[torch.Tensor] | None = None
+        teacher_latent: torch.Tensor | list[torch.Tensor] | None = None
+        prototype_probs = None
+        ema_beta = None
+        loss_type = None
+
+        # Helper to extract tensor from various latent formats
+        def extract_latent_tensor(
+            latent_data: torch.Tensor | LatentState | list | dict | None,
+        ) -> torch.Tensor | None:
+            """Extract tensor from various latent data formats."""
+            if latent_data is None:
+                return None
+            if isinstance(latent_data, torch.Tensor):
+                return latent_data
+            if isinstance(latent_data, LatentState):
+                # Use patch_tokens as the primary latent representation
+                # For forecast steps > 0, patch_tokens is None, so fall back to z_pre_norm
+                if latent_data.patch_tokens is not None:
+                    return latent_data.patch_tokens
+                if latent_data.z_pre_norm is not None:
+                    # z_pre_norm may still include register/class tokens; it is returned
+                    # as-is here, so those tokens also enter the collapse statistics
+                    return latent_data.z_pre_norm
+                return None
+            if isinstance(latent_data, list) and len(latent_data) > 0:
+                return extract_latent_tensor(latent_data[0])
+            if isinstance(latent_data, dict):
+                # Try common keys
+                for key in ["latent", "patch_tokens", "z_pre_norm"]:
+                    if key in latent_data:
+                        return extract_latent_tensor(latent_data[key])
+            return None
+
+        # Find SSL loss type and extract teacher latents
+        for _loss_name, target_aux in targets_and_auxs.items():
+            if not hasattr(target_aux, "latent") or not target_aux.latent:
+                continue
+
+            # Handle both list[dict] and dict formats
+            if isinstance(target_aux.latent, list):
+                target_latent_dict = target_aux.latent[0] if target_aux.latent else {}
+            else:
+                target_latent_dict = target_aux.latent
+
+            # Try SSL-specific keys first
+            for ssl_type in ["JEPA", "DINO", "iBOT"]:
+                if ssl_type in target_latent_dict:
+                    loss_type = ssl_type
+                    teacher_latent = extract_latent_tensor(target_latent_dict[ssl_type])
+                    break
+
+        # Extract student latents from predictions
+        if preds.latent and len(preds.latent) > 0:
+            # First, try SSL-specific keys (JEPA/DINO/iBOT) from first step
+            pred_latent_dict = preds.latent[0]
+            for ssl_type in ["JEPA", "DINO", "iBOT"]:
+                if ssl_type in pred_latent_dict:
+                    student_latent = extract_latent_tensor(pred_latent_dict[ssl_type])
+                    loss_type = ssl_type
+                    break
+
+            # If no SSL keys found, extract from latent_state for all forecast steps
+            if student_latent is None:
+                student_latents_list: list[torch.Tensor] = []
+                for step_latent_dict in preds.latent:
+                    if "latent_state" in step_latent_dict:
+                        step_tensor
= extract_latent_tensor(step_latent_dict["latent_state"]) + if step_tensor is not None: + student_latents_list.append(step_tensor) + + # Use list if multiple steps, single tensor otherwise + if len(student_latents_list) > 1: + student_latent = student_latents_list + n_steps = len(student_latents_list) + logger.debug(f"Collapse monitor - forecasting mode: {n_steps} steps") + elif len(student_latents_list) == 1: + student_latent = student_latents_list[0] + + # Get EMA beta from target_and_aux_calculators + for _calc_name, calculator in self.target_and_aux_calculators.items(): + if isinstance(calculator, EMATeacher): + batch_size_total = self.get_batch_size_total(self.batch_size_per_gpu) + step = batch_size_total * self.cf.general.istep + ema_beta = calculator.get_current_beta(step) + break + + # Debug logging + if student_latent is not None: + if isinstance(student_latent, list): + shapes = [t.shape for t in student_latent] + logger.debug(f"Collapse monitor - student (list): {len(shapes)} steps") + else: + logger.debug(f"Collapse monitor - student: shape={student_latent.shape}") + else: + logger.debug("Collapse monitor - student_latent is None") + + if teacher_latent is not None: + if isinstance(teacher_latent, list): + logger.debug(f"Collapse monitor - teacher (list): {len(teacher_latent)} steps") + else: + shape = teacher_latent.shape if isinstance(teacher_latent, torch.Tensor) else "N/A" + logger.debug(f"Collapse monitor - teacher: shape={shape}") + + # Compute metrics if we have valid student latent + has_valid_latent = student_latent is not None and ( + isinstance(student_latent, torch.Tensor) + or (isinstance(student_latent, list) and len(student_latent) > 0) + ) + + if has_valid_latent: + # Prepare teacher latent (must match student format if provided) + teacher_for_metrics = None + if teacher_latent is not None: + is_valid_tensor = isinstance(teacher_latent, torch.Tensor) + is_valid_list = isinstance(teacher_latent, list) and len(teacher_latent) > 0 + if is_valid_tensor or is_valid_list: + teacher_for_metrics = teacher_latent + + self.collapse_monitor.compute_metrics( + student_latent=student_latent, + teacher_latent=teacher_for_metrics, + prototype_probs=prototype_probs, + ema_beta=ema_beta, + loss_type=loss_type, + ) + else: + logger.debug( + f"Collapse monitor - skipping compute_metrics: " + f"student_latent is {'None' if student_latent is None else type(student_latent)}" + ) + + def _log_collapse_metrics(self, stage: Stage) -> None: + """ + Log cached collapse monitoring metrics. + """ + metrics = self.collapse_monitor.get_cached_metrics() + if metrics and is_root(): + self.train_logger.log_metrics(stage, metrics) diff --git a/tests/test_collapse_monitor.py b/tests/test_collapse_monitor.py new file mode 100644 index 000000000..e5fc75985 --- /dev/null +++ b/tests/test_collapse_monitor.py @@ -0,0 +1,878 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +"""Unit tests for collapse monitoring metrics.""" + +import pytest +import torch + +from weathergen.train.collapse_monitor import CollapseMonitor + + +@pytest.fixture +def default_config(): + """Default enabled config for collapse monitoring.""" + return { + "enabled": True, + "compute_frequency": 100, + "log_frequency": 100, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "both", + "sample_size": 2048, + }, + "singular_values": { + "enabled": True, + "tensor_source": "both", + "sample_size": 2048, + }, + "dimension_variance": { + "enabled": True, + "tensor_source": "both", + }, + "prototype_entropy": { + "enabled": True, + }, + "ema_beta": { + "enabled": True, + }, + }, + } + + +@pytest.fixture +def monitor(default_config): + """Create a collapse monitor with default config.""" + device = torch.device("cpu") + return CollapseMonitor(default_config, device) + + +class TestCollapseMonitorInitialization: + """Test CollapseMonitor initialization.""" + + def test_disabled_monitor(self): + """Test that disabled monitor doesn't compute metrics.""" + config = {"enabled": False} + monitor = CollapseMonitor(config, torch.device("cpu")) + assert not monitor.enabled + assert not monitor.should_compute(100) + assert not monitor.should_log(100) + + def test_enabled_monitor(self, default_config): + """Test that enabled monitor computes at correct intervals.""" + monitor = CollapseMonitor(default_config, torch.device("cpu")) + assert monitor.enabled + assert monitor.should_compute(0) + assert monitor.should_compute(100) + assert not monitor.should_compute(50) + + def test_frequency_settings(self): + """Test custom frequency settings.""" + config = { + "enabled": True, + "compute_frequency": 50, + "log_frequency": 200, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + assert monitor.should_compute(50) + assert monitor.should_compute(100) # 100 is a multiple of 50 + assert not monitor.should_compute(75) # 75 is not a multiple of 50 + assert monitor.should_log(200) + assert not monitor.should_log(100) + + +class TestEffectiveRank: + """Test effective rank computation.""" + + def test_full_rank_matrix(self, monitor): + """Full rank random matrix should have effective rank close to min(N, D).""" + torch.manual_seed(42) + # Create a full-rank matrix with orthogonal columns + dim = 64 + num_samples = 128 + z = torch.randn(num_samples, dim) + # Make it more orthogonal via QR decomposition + q, _ = torch.linalg.qr(z.T) + z = q.T # Now z is [dim, dim] with orthogonal rows + z = torch.cat([z, torch.randn(num_samples - dim, dim)], dim=0) + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # For a full-rank matrix, effective rank should be significant portion of D + assert eff_rank > dim * 0.3, f"Expected effective rank > {dim * 0.3}, got {eff_rank}" + + def test_low_rank_matrix(self, monitor): + """Low rank matrix should have effective rank close to actual rank.""" + torch.manual_seed(42) + # Create a rank-5 matrix + actual_rank = 5 + num_samples, dim = 128, 64 + u_mat = torch.randn(num_samples, actual_rank) + v_mat = torch.randn(actual_rank, dim) + z = u_mat @ v_mat + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Effective rank should be close to actual rank + assert eff_rank < actual_rank * 2, ( + f"Expected effective rank < {actual_rank * 2}, got {eff_rank}" + ) + assert eff_rank > actual_rank * 0.5, ( + f"Expected effective rank > {actual_rank * 0.5}, got {eff_rank}" + ) + + def test_collapsed_matrix(self, monitor): + """Completely 
collapsed matrix should have effective rank ~1.""" + num_samples, dim = 128, 64 + # All rows are the same (rank 1) + row = torch.randn(1, dim) + z = row.expand(num_samples, dim).clone() + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Effective rank should be very close to 1 + assert eff_rank < 2, f"Expected effective rank < 2, got {eff_rank}" + + def test_3d_tensor_flattening(self, monitor): + """Test that [B, N, D] tensors are properly flattened.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + z = torch.randn(batch_size, num_patches, dim) + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Should compute without error and return reasonable value + assert 1 <= eff_rank <= dim + + +class TestSingularValues: + """Test singular value spectrum computation.""" + + def test_singular_value_statistics(self, monitor): + """Test that singular value statistics are correctly computed.""" + torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + + sv_metrics = monitor._compute_singular_values(z, sample_size=0) + + # Check that we got min, max, mean statistics + assert "sv_min" in sv_metrics + assert "sv_max" in sv_metrics + assert "sv_mean" in sv_metrics + assert "sv_concentration" in sv_metrics + + # Max should be >= mean >= min + assert sv_metrics["sv_max"] >= sv_metrics["sv_mean"] + assert sv_metrics["sv_mean"] >= sv_metrics["sv_min"] + + def test_concentration_ratio(self, monitor): + """Test singular value concentration ratio.""" + torch.manual_seed(42) + # Create a rank-1 matrix where first SV dominates + num_samples, dim = 128, 64 + # Use outer product to create a truly rank-1 dominated matrix + u_vec = torch.randn(num_samples, 1) + v_vec = torch.randn(1, dim) + z = u_vec @ v_vec * 10 + torch.randn(num_samples, dim) * 0.01 # Strong rank-1 component + + sv_metrics = monitor._compute_singular_values(z, sample_size=0) + + # Concentration should be high when one SV dominates + assert "sv_concentration" in sv_metrics + assert sv_metrics["sv_concentration"] > 0.8 # First SV dominates strongly + + # Max should be much larger than min for rank-1 dominated matrix + assert sv_metrics["sv_max"] > sv_metrics["sv_min"] * 10 + + def test_uniform_singular_values(self, monitor): + """Test with random matrix (spread singular values).""" + torch.manual_seed(42) + # Random matrix will have spread singular values + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + + sv_metrics = monitor._compute_singular_values(z, sample_size=0) + + # Concentration should be relatively low for random matrix + assert sv_metrics["sv_concentration"] < 0.2 + + # All statistics should be positive + assert sv_metrics["sv_min"] > 0 + assert sv_metrics["sv_max"] > 0 + assert sv_metrics["sv_mean"] > 0 + + +class TestDimensionVariance: + """Test per-dimension variance computation.""" + + def test_random_matrix_balanced_variance(self, monitor): + """Random matrix should have balanced variance across dimensions.""" + torch.manual_seed(42) + num_samples, dim = 1024, 64 + z = torch.randn(num_samples, dim) + + var_metrics = monitor._compute_dimension_variance(z) + + # All variances should be close to 1 for standard normal + assert abs(var_metrics["var_mean"] - 1.0) < 0.2 + # Variance ratio should be small for random matrix + var_ratio = var_metrics["var_max"] / (var_metrics["var_min"] + 1e-8) + assert var_ratio < 5 # Balanced dimensions + + def test_dead_dimensions(self, monitor): + """Test detection of dead (zero-variance) dimensions.""" 
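+        # Dimensions held at a constant value have (near-)zero variance and should surface in var_min.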
+ torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + # Kill some dimensions (set to constant) + z[:, :10] = 0.5 + + var_metrics = monitor._compute_dimension_variance(z) + + # Minimum variance should be very close to 0 (dead dimensions) + assert var_metrics["var_min"] < 1e-6 + + def test_imbalanced_dimensions(self, monitor): + """Test with highly imbalanced dimension variances.""" + torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + # Scale some dimensions much more than others + z[:, 0] *= 100 + z[:, 1:10] *= 0.01 + + var_metrics = monitor._compute_dimension_variance(z) + + # Large variance ratio indicates imbalance + var_ratio = var_metrics["var_max"] / (var_metrics["var_min"] + 1e-8) + assert var_ratio > 1000 + + +class TestPrototypeEntropy: + """Test DINO prototype entropy computation.""" + + def test_uniform_prototype_distribution(self, monitor): + """Uniform prototype distribution should have entropy ~1.""" + batch_size, num_prototypes = 128, 64 + # Uniform distribution + probs = torch.ones(batch_size, num_prototypes) / num_prototypes + + entropy = monitor._compute_prototype_entropy(probs) + + # Normalized entropy should be close to 1 + assert abs(entropy - 1.0) < 0.01 + + def test_single_prototype_collapse(self, monitor): + """Collapse to single prototype should have entropy ~0.""" + batch_size, num_prototypes = 128, 64 + # All mass on first prototype + probs = torch.zeros(batch_size, num_prototypes) + probs[:, 0] = 1.0 + + entropy = monitor._compute_prototype_entropy(probs) + + # Normalized entropy should be close to 0 + assert entropy < 0.01 + + def test_partial_collapse(self, monitor): + """Partial collapse should have intermediate entropy.""" + batch_size, num_prototypes = 128, 64 + # Only 4 prototypes used uniformly (much stronger collapse) + probs = torch.zeros(batch_size, num_prototypes) + probs[:, :4] = 0.25 # Only 4 out of 64 prototypes + + entropy = monitor._compute_prototype_entropy(probs) + + # Entropy should be between 0 and 1 (log(4)/log(64) ≈ 0.33) + assert 0.2 < entropy < 0.5 + + +class TestMetricsCaching: + """Test metrics caching and averaging.""" + + def test_cache_accumulation(self, monitor): + """Test that metrics are properly cached.""" + torch.manual_seed(42) + z1 = torch.randn(64, 32) + z2 = torch.randn(64, 32) + + # Compute metrics twice + monitor.compute_metrics(student_latent=z1) + monitor.compute_metrics(student_latent=z2) + + # Cache should contain averaged values + cached = monitor.get_cached_metrics() + assert "collapse.student.effective_rank" in cached + + def test_cache_clear(self, monitor): + """Test that cache is cleared after get_cached_metrics.""" + torch.manual_seed(42) + z = torch.randn(64, 32) + + monitor.compute_metrics(student_latent=z) + _ = monitor.get_cached_metrics() + + # Second call should return empty + cached = monitor.get_cached_metrics() + assert len(cached) == 0 + + +class TestIntegration: + """Integration tests with both student and teacher tensors.""" + + def test_full_metrics_computation(self, monitor): + """Test computing all metrics with both student and teacher.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + student = torch.randn(batch_size, num_patches, dim) + teacher = torch.randn(batch_size, num_patches, dim) + + metrics = monitor.compute_metrics( + student_latent=student, + teacher_latent=teacher, + ema_beta=0.999, + loss_type="JEPA", + ) + + # Check that both student and teacher metrics are computed + assert 
"collapse.student.effective_rank" in metrics + assert "collapse.teacher.effective_rank" in metrics + assert "collapse.student.var_min" in metrics + assert "collapse.teacher.var_min" in metrics + assert "collapse.ema_beta" in metrics + assert metrics["collapse.ema_beta"] == 0.999 + + def test_dino_prototype_entropy(self, monitor): + """Test DINO prototype entropy computation.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + num_prototypes = 128 + student = torch.randn(batch_size, num_patches, dim) + probs = torch.softmax(torch.randn(batch_size, num_prototypes), dim=-1) + + metrics = monitor.compute_metrics( + student_latent=student, + prototype_probs=probs, + loss_type="DINO", + ) + + assert "collapse.dino.prototype_entropy" in metrics + assert 0 <= metrics["collapse.dino.prototype_entropy"] <= 1 + + def test_disabled_metrics(self): + """Test that disabled metrics are not computed.""" + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": {"enabled": False}, + "singular_values": {"enabled": False}, + "dimension_variance": {"enabled": True, "tensor_source": "student"}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + + torch.manual_seed(42) + z = torch.randn(64, 32) + metrics = monitor.compute_metrics(student_latent=z) + + # Only dimension variance should be computed + assert "collapse.student.var_min" in metrics + assert "collapse.student.effective_rank" not in metrics + assert "collapse.student.sv_max" not in metrics + + +class TestSampling: + """Test row sampling for SVD computations.""" + + def test_sampling_reduces_computation(self, monitor): + """Test that sampling works for large tensors.""" + torch.manual_seed(42) + num_samples, dim = 10000, 64 + z = torch.randn(num_samples, dim) + + # With sampling + eff_rank_sampled = monitor._compute_effective_rank(z, sample_size=1024) + # Without sampling + eff_rank_full = monitor._compute_effective_rank(z, sample_size=0) + + # Results should be in same ballpark + assert abs(eff_rank_sampled - eff_rank_full) < eff_rank_full * 0.3 + + def test_no_sampling_when_small(self, monitor): + """Test that small tensors aren't sampled.""" + torch.manual_seed(42) + num_samples, dim = 100, 64 + z = torch.randn(num_samples, dim) + + # Sample size larger than N + sampled = monitor._sample_rows(z, sample_size=1024) + assert sampled.shape[0] == num_samples # No sampling occurred + + +class TestConfigValidation: + """Test configuration validation.""" + + def test_invalid_compute_frequency(self): + """Test that non-positive compute_frequency raises error.""" + config = {"enabled": True, "compute_frequency": 0} + with pytest.raises(ValueError, match="compute_frequency must be positive"): + CollapseMonitor(config, torch.device("cpu")) + + config = {"enabled": True, "compute_frequency": -1} + with pytest.raises(ValueError, match="compute_frequency must be positive"): + CollapseMonitor(config, torch.device("cpu")) + + def test_invalid_log_frequency(self): + """Test that non-positive log_frequency raises error.""" + config = {"enabled": True, "log_frequency": 0} + with pytest.raises(ValueError, match="log_frequency must be positive"): + CollapseMonitor(config, torch.device("cpu")) + + def test_invalid_tensor_source(self): + """Test that invalid tensor_source raises error.""" + config = { + "enabled": True, + "metrics": { + "effective_rank": {"tensor_source": "invalid_source"}, + }, + } + with 
pytest.raises(ValueError, match="Invalid tensor_source"): + CollapseMonitor(config, torch.device("cpu")) + + def test_invalid_forecast_aggregation(self): + """Test that invalid forecast_aggregation raises error.""" + config = { + "enabled": True, + "metrics": { + "effective_rank": {"forecast_aggregation": "invalid_agg"}, + }, + } + with pytest.raises(ValueError, match="Invalid forecast_aggregation"): + CollapseMonitor(config, torch.device("cpu")) + + def test_valid_tensor_sources(self): + """Test that valid tensor_source values are accepted.""" + for source in ["student", "teacher", "both"]: + config = { + "enabled": True, + "metrics": { + "effective_rank": {"tensor_source": source}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + assert monitor is not None + + def test_valid_forecast_aggregations(self): + """Test that valid forecast_aggregation values are accepted.""" + for agg in ["all", "aggregate_only", "per_step_only"]: + config = { + "enabled": True, + "metrics": { + "effective_rank": {"forecast_aggregation": agg}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + assert monitor is not None + + +class TestDimensionVarianceValidation: + """Test validation in _compute_dimension_variance.""" + + def test_empty_tensor(self, monitor): + """Test that empty tensor returns empty dict.""" + z = torch.empty(0, 64) + result = monitor._compute_dimension_variance(z) + assert result == {} + + def test_nan_tensor(self, monitor): + """Test that tensor with NaN returns empty dict.""" + z = torch.randn(64, 32) + z[0, 0] = float("nan") + result = monitor._compute_dimension_variance(z) + assert result == {} + + def test_inf_tensor(self, monitor): + """Test that tensor with Inf returns empty dict.""" + z = torch.randn(64, 32) + z[0, 0] = float("inf") + result = monitor._compute_dimension_variance(z) + assert result == {} + + def test_single_sample(self, monitor): + """Test that single sample returns empty dict (can't compute variance).""" + z = torch.randn(1, 32) + result = monitor._compute_dimension_variance(z) + assert result == {} + + def test_valid_tensor(self, monitor): + """Test that valid tensor returns metrics.""" + torch.manual_seed(42) + z = torch.randn(64, 32) + result = monitor._compute_dimension_variance(z) + assert "var_min" in result + assert "var_mean" in result + assert "var_max" in result + + +class TestForecastingSequences: + """Test forecasting with sequences of latents.""" + + @pytest.fixture + def forecast_config(self): + """Config for forecasting tests.""" + return { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "student", + "sample_size": 0, + "forecast_aggregation": "all", + }, + "singular_values": { + "enabled": True, + "tensor_source": "student", + "sample_size": 0, + "forecast_aggregation": "all", + }, + "dimension_variance": { + "enabled": True, + "tensor_source": "student", + "forecast_aggregation": "all", + }, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + + @pytest.fixture + def forecast_monitor(self, forecast_config): + """Create a monitor configured for forecasting.""" + return CollapseMonitor(forecast_config, torch.device("cpu")) + + def test_sequence_per_step_metrics(self, forecast_monitor): + """Test that per-step metrics are computed for sequences.""" + torch.manual_seed(42) + # Create 3 time steps of latents + latents = [torch.randn(32, 64) for _ in range(3)] + + metrics = 
forecast_monitor.compute_metrics(student_latent=latents) + + # Check per-step effective rank metrics + assert "collapse.student.effective_rank.step_0" in metrics + assert "collapse.student.effective_rank.step_1" in metrics + assert "collapse.student.effective_rank.step_2" in metrics + + # Check per-step variance metrics + assert "collapse.student.var_min.step_0" in metrics + assert "collapse.student.var_min.step_1" in metrics + assert "collapse.student.var_min.step_2" in metrics + + def test_sequence_aggregate_metrics(self, forecast_monitor): + """Test that aggregate metrics are computed for sequences.""" + torch.manual_seed(42) + latents = [torch.randn(32, 64) for _ in range(3)] + + metrics = forecast_monitor.compute_metrics(student_latent=latents) + + # Check aggregate effective rank metrics + assert "collapse.student.effective_rank.mean" in metrics + assert "collapse.student.effective_rank.min" in metrics + assert "collapse.student.effective_rank.max" in metrics + assert "collapse.student.effective_rank.degradation" in metrics + + # Verify aggregates are consistent + step_values = [ + metrics["collapse.student.effective_rank.step_0"], + metrics["collapse.student.effective_rank.step_1"], + metrics["collapse.student.effective_rank.step_2"], + ] + assert metrics["collapse.student.effective_rank.min"] == min(step_values) + assert metrics["collapse.student.effective_rank.max"] == max(step_values) + assert abs(metrics["collapse.student.effective_rank.mean"] - sum(step_values) / 3) < 1e-6 + + def test_degradation_metric(self, forecast_monitor): + """Test degradation metric (final/initial ratio).""" + torch.manual_seed(42) + # Create latents with controlled rank degradation + dim = 64 + # Step 0: Full rank random matrix + step_0 = torch.randn(128, dim) + # Step 1: Lower rank (rank ~32) + u1 = torch.randn(128, 32) + v1 = torch.randn(32, dim) + step_1 = u1 @ v1 + # Step 2: Even lower rank (rank ~8) + u2 = torch.randn(128, 8) + v2 = torch.randn(8, dim) + step_2 = u2 @ v2 + + latents = [step_0, step_1, step_2] + metrics = forecast_monitor.compute_metrics(student_latent=latents) + + # Degradation should be < 1 since rank decreases + degradation = metrics["collapse.student.effective_rank.degradation"] + assert degradation < 1.0, f"Expected degradation < 1.0, got {degradation}" + + # Verify degradation is step_2 / step_0 + expected = ( + metrics["collapse.student.effective_rank.step_2"] + / metrics["collapse.student.effective_rank.step_0"] + ) + assert abs(degradation - expected) < 1e-6 + + def test_aggregate_only_mode(self): + """Test forecast_aggregation='aggregate_only' mode.""" + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "student", + "forecast_aggregation": "aggregate_only", + }, + "singular_values": {"enabled": False}, + "dimension_variance": {"enabled": False}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + + torch.manual_seed(42) + latents = [torch.randn(32, 64) for _ in range(3)] + metrics = monitor.compute_metrics(student_latent=latents) + + # Should NOT have per-step metrics + assert "collapse.student.effective_rank.step_0" not in metrics + + # Should have aggregate metrics + assert "collapse.student.effective_rank.mean" in metrics + assert "collapse.student.effective_rank.min" in metrics + + def test_per_step_only_mode(self): + """Test forecast_aggregation='per_step_only' mode.""" + 
config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "student", + "forecast_aggregation": "per_step_only", + }, + "singular_values": {"enabled": False}, + "dimension_variance": {"enabled": False}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + + torch.manual_seed(42) + latents = [torch.randn(32, 64) for _ in range(3)] + metrics = monitor.compute_metrics(student_latent=latents) + + # Should have per-step metrics + assert "collapse.student.effective_rank.step_0" in metrics + assert "collapse.student.effective_rank.step_1" in metrics + + # Should NOT have aggregate metrics + assert "collapse.student.effective_rank.mean" not in metrics + assert "collapse.student.effective_rank.degradation" not in metrics + + def test_empty_sequence(self, forecast_monitor): + """Test that empty sequence returns no metrics.""" + metrics = forecast_monitor.compute_metrics(student_latent=[]) + + # No effective rank metrics should be present + assert not any(key.startswith("collapse.student.effective_rank") for key in metrics) + + def test_single_step_sequence(self, forecast_monitor): + """Test sequence with single step (no degradation possible).""" + torch.manual_seed(42) + latents = [torch.randn(32, 64)] + + metrics = forecast_monitor.compute_metrics(student_latent=latents) + + # Should have step_0 + assert "collapse.student.effective_rank.step_0" in metrics + + # Should have aggregates (single value) + assert "collapse.student.effective_rank.mean" in metrics + + # Degradation should be 1.0 (same step) + assert metrics["collapse.student.effective_rank.degradation"] == 1.0 + + def test_mixed_single_and_sequence(self, forecast_monitor): + """Test with single tensor for student and sequence for teacher.""" + # Need to enable teacher monitoring + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "both", + "forecast_aggregation": "all", + }, + "singular_values": {"enabled": False}, + "dimension_variance": {"enabled": False}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + + torch.manual_seed(42) + student = torch.randn(32, 64) # Single tensor + teacher = [torch.randn(32, 64) for _ in range(3)] # Sequence + + metrics = monitor.compute_metrics(student_latent=student, teacher_latent=teacher) + + # Student should have single metric + assert "collapse.student.effective_rank" in metrics + assert "collapse.student.effective_rank.step_0" not in metrics + + # Teacher should have sequence metrics + assert "collapse.teacher.effective_rank.step_0" in metrics + assert "collapse.teacher.effective_rank.mean" in metrics + + def test_3d_tensor_sequence(self, forecast_monitor): + """Test sequence of 3D tensors [B, N, D].""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + latents = [torch.randn(batch_size, num_patches, dim) for _ in range(3)] + + metrics = forecast_monitor.compute_metrics(student_latent=latents) + + # Should flatten each tensor and compute metrics + assert "collapse.student.effective_rank.step_0" in metrics + assert "collapse.student.effective_rank.mean" in metrics + + # Values should be reasonable (between 1 and dim) + for i in range(3): + value = metrics[f"collapse.student.effective_rank.step_{i}"] + assert 
1 <= value <= dim + + +class TestSequenceSingularValues: + """Test singular value metrics for sequences.""" + + @pytest.fixture + def sv_monitor(self): + """Monitor configured for singular value tests.""" + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": {"enabled": False}, + "singular_values": { + "enabled": True, + "tensor_source": "student", + "sample_size": 0, + "forecast_aggregation": "all", + }, + "dimension_variance": {"enabled": False}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + return CollapseMonitor(config, torch.device("cpu")) + + def test_sv_sequence_metrics(self, sv_monitor): + """Test singular value metrics for sequences.""" + torch.manual_seed(42) + latents = [torch.randn(64, 32) for _ in range(3)] + + metrics = sv_monitor.compute_metrics(student_latent=latents) + + # Per-step metrics + assert "collapse.student.sv_min.step_0" in metrics + assert "collapse.student.sv_max.step_0" in metrics + assert "collapse.student.sv_concentration.step_0" in metrics + + # Aggregate metrics + assert "collapse.student.sv_min.mean" in metrics + assert "collapse.student.sv_max.mean" in metrics + assert "collapse.student.sv_concentration.mean" in metrics + + +class TestSequenceVariance: + """Test dimension variance metrics for sequences.""" + + @pytest.fixture + def var_monitor(self): + """Monitor configured for variance tests.""" + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": {"enabled": False}, + "singular_values": {"enabled": False}, + "dimension_variance": { + "enabled": True, + "tensor_source": "student", + "forecast_aggregation": "all", + }, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + return CollapseMonitor(config, torch.device("cpu")) + + def test_variance_sequence_metrics(self, var_monitor): + """Test variance metrics for sequences.""" + torch.manual_seed(42) + latents = [torch.randn(64, 32) for _ in range(3)] + + metrics = var_monitor.compute_metrics(student_latent=latents) + + # Per-step metrics + assert "collapse.student.var_min.step_0" in metrics + assert "collapse.student.var_mean.step_0" in metrics + assert "collapse.student.var_max.step_0" in metrics + + # Aggregate metrics + assert "collapse.student.var_min.mean" in metrics + assert "collapse.student.var_mean.mean" in metrics + assert "collapse.student.var_max.mean" in metrics + + def test_variance_detects_collapse_over_time(self, var_monitor): + """Test that variance metrics can detect collapse over forecast steps.""" + torch.manual_seed(42) + dim = 32 + + # Step 0: Normal variance + step_0 = torch.randn(128, dim) + + # Step 1: Some dimensions start dying + step_1 = torch.randn(128, dim) + step_1[:, :8] *= 0.1 # Reduce variance in 8 dims + + # Step 2: More dimensions dead + step_2 = torch.randn(128, dim) + step_2[:, :16] *= 0.01 # Almost dead in 16 dims + + latents = [step_0, step_1, step_2] + metrics = var_monitor.compute_metrics(student_latent=latents) + + # var_min should decrease over steps (more dead dimensions) + assert metrics["collapse.student.var_min.step_0"] > metrics["collapse.student.var_min.step_1"] + assert metrics["collapse.student.var_min.step_1"] > metrics["collapse.student.var_min.step_2"]
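
The sketch below is editorial commentary, not part of the diff: it drives CollapseMonitor by hand with random stand-in latents, mirroring the should_compute / compute_metrics / should_log / get_cached_metrics sequence that trainer.py wires into the training loop. The config dict follows the YAML schema added above; the frequencies and tensor shapes are illustrative.

import torch

from weathergen.train.collapse_monitor import CollapseMonitor

config = {
    "enabled": True,
    "compute_frequency": 10,  # compute metrics every 10 steps
    "log_frequency": 50,      # read back averaged metrics every 50 steps
    "metrics": {
        "effective_rank": {"enabled": True, "tensor_source": "student", "sample_size": 2048},
        "singular_values": {"enabled": False},
        "dimension_variance": {"enabled": True, "tensor_source": "student"},
        "prototype_entropy": {"enabled": False},
        "ema_beta": {"enabled": False},
    },
}
monitor = CollapseMonitor(config, torch.device("cpu"))

for step in range(100):
    if monitor.should_compute(step):
        latent = torch.randn(4, 32, 64)  # [B, N, D] stand-in for model latents
        monitor.compute_metrics(student_latent=latent)
    if monitor.should_log(step):
        averaged = monitor.get_cached_metrics()  # averages and clears the cache
        # e.g. {"collapse.student.effective_rank": ..., "collapse.student.var_min": ...}

Passing a list of [B, N, D] tensors instead of a single tensor switches the monitor into the per-step forecasting mode exercised by TestForecastingSequences.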
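A worked example helps pin down the halflife semantics of the refactored EMAModel.get_current_beta. This standalone sketch assumes the halflife is expressed directly in training samples, which is how the new code consumes it; the function mirrors get_current_beta rather than importing it.

def ema_beta(
    batch_size: int,
    halflife_steps: float,
    cur_step: int,
    rampup_ratio: float | None = None,
) -> float:
    # Mirror of EMAModel.get_current_beta: optionally ramp the halflife up
    # from zero early in training, then convert it to a per-update
    # interpolation factor.
    if rampup_ratio is not None:
        halflife_steps = min(halflife_steps, cur_step / rampup_ratio)
    return 0.5 ** (batch_size / max(halflife_steps, 1e-6))

# Halflife of 1000 samples at batch size 4: each update keeps ~99.72% of
# the teacher, and after 250 updates (250 * 4 = 1000 samples) the original
# teacher weights have decayed to roughly one half.
beta = ema_beta(4, 1000.0, cur_step=10_000)
print(beta, beta**250)  # ~0.99723 ~0.5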
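Finally, a self-contained illustration of the RankMe computation behind _compute_effective_rank (again a sketch, not the module itself): the effective rank is the exponential of the entropy of the normalized singular-value spectrum, so a healthy representation scores a sizeable fraction of the embedding width while a collapsed one scores close to 1.

import torch

def effective_rank(z: torch.Tensor) -> float:
    # Same recipe as CollapseMonitor._compute_effective_rank:
    # center, SVD, entropy of normalized singular values, exponentiate.
    z = z - z.mean(dim=0, keepdim=True)
    s = torch.linalg.svd(z, full_matrices=False).S
    p = s / (s.sum() + 1e-8)  # singular values as a probability distribution
    entropy = -(p * torch.log(p + 1e-8)).sum()
    return torch.exp(entropy).item()

torch.manual_seed(0)
healthy = torch.randn(512, 64)                        # uses all 64 dimensions
collapsed = torch.randn(512, 1) @ torch.randn(1, 64)  # rank 1: samples colinear
print(effective_rank(healthy))    # large: a sizeable fraction of 64
print(effective_rank(collapsed))  # ~1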