diff --git a/config/config_jepa.yml b/config/config_jepa.yml index fc27da8c9..78c398e7d 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -26,7 +26,7 @@ ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 ae_global_dim_embed: 2048 -ae_global_num_blocks: 2 +ae_global_num_blocks: 0 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -37,7 +37,7 @@ ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 ae_global_trailing_layer_norm: False -ae_aggregation_num_blocks: 8 +ae_aggregation_num_blocks: 12 ae_aggregation_num_heads: 32 ae_aggregation_dropout_rate: 0.1 ae_aggregation_with_qk_lnorm: True @@ -130,10 +130,33 @@ data_loading : # config for training training_config: - + # training_mode: "masking", "student_teacher", "latent_loss" training_mode: ["student_teacher"] + # Collapse monitoring for SSL training (JEPA/DINO/iBOT) + # Detects representation collapse via various metrics + collapse_monitoring: + enabled: true + compute_frequency: 100 # batches between metric computations + log_frequency: 100 # batches between metric logging + metrics: + effective_rank: + enabled: true + tensor_source: "both" # "student", "teacher", or "both" + sample_size: 2048 # max samples for SVD (0 = no sampling) + singular_values: + enabled: true + tensor_source: "both" + sample_size: 2048 + dimension_variance: + enabled: true + tensor_source: "both" # cheap to compute, good early indicator + prototype_entropy: + enabled: true # only applies to DINO + ema_beta: + enabled: true + num_mini_epochs: 32 samples_per_mini_epoch: 4096 shuffle: True @@ -148,10 +171,10 @@ training_config: learning_rate_scheduling : lr_start: 1e-6 - lr_max: 5e-5 + lr_max: 1e-4 lr_final_decay: 1e-6 lr_final: 0.0 - num_steps_warmup: 512 + num_steps_warmup: 4096 num_steps_cooldown: 512 policy_warmup: "cosine" policy_decay: "constant" @@ -159,14 +182,25 @@ training_config: parallel_scaling_policy: "sqrt" optimizer: - grad_clip: 1.0 - weight_decay: 0.1 + # Optimizer type: "adamw" (default) or "muon_adamw" (Muon for hidden weights, AdamW for embeddings/heads) + type: "muon_adamw" + grad_clip: 0.1 + weight_decay: 0.05 log_grad_norms: False adamw : # parameters are scaled by number of DDP workers beta1 : 0.975 beta2 : 0.9875 eps : 2e-08 + muon: + # Learning rate multiplier for Muon relative to base LR (muon_lr = base_lr * lr_multiplier) + lr_multiplier: 30.0 + # Momentum factor for Muon SGD + momentum: 0.95 + # Use Nesterov momentum + nesterov: true + # Weight decay for Muon parameters (uses optimizer.weight_decay if not specified) + weight_decay: 0.05 losses : { "student-teacher": { @@ -179,16 +213,20 @@ training_config: "num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768, "dropout_rate": 0.1, target_source_correspondence: {0 : {0 : "subset"} }, + }, }, - }, - target_and_aux_calc: { "EMATeacher" : - { ema_ramp_up_ratio : 0.09, - ema_halflife_in_thousands: 1e-3, - model_param_overrides : { - training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} - }, - } - } + target_and_aux_calc: {FrozenTeacher: { + teacher_run_id: "yoqxf234", # "zosrc8ti", # Required + teacher_mini_epoch: -1}}, + # }, + # target_and_aux_calc: { "EMATeacher" : + # { ema_ramp_up_ratio : null, + # ema_halflife_in_thousands: 1e-1, + # model_param_overrides : { + # training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} + # }, + # } + # } } } diff --git a/config/config_jepa_finetuning.yml b/config/config_jepa_finetuning.yml index 
e9bf055a8..b12d9f98f 100644 --- a/config/config_jepa_finetuning.yml +++ b/config/config_jepa_finetuning.yml @@ -92,7 +92,8 @@ zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore ##################################### # streams_directory: "./config/streams/era5_1deg/" -streams_directory: "./config/streams/era5_synop_finetuning/" +# streams_directory: "./config/streams/era5_synop_finetuning/" +streams_directory: "./config/streams/era5_nppatms_finetuning/" streams: ??? general: @@ -139,8 +140,8 @@ training_config: samples_per_mini_epoch: 4096 shuffle: True - start_date: 1979-01-01T00:00 - end_date: 2022-12-31T00:00 + start_date: 2012-01-01T00:00 + end_date: 2021-12-31T00:00 time_window_step: 06:00:00 time_window_len: 06:00:00 @@ -271,7 +272,7 @@ validation_config: # write samples in normalized model space normalized_samples: False, # output streams to write; default all - streams: ["SurfaceCombined"], + streams: ["NPPATMS"], } # run validation before training starts (mainly for model development) diff --git a/config/default_config.yml b/config/default_config.yml index 613078ebe..1562b1844 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -11,7 +11,7 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 -ae_local_dim_embed: 512 #1024 +ae_local_dim_embed: 512 ae_local_num_blocks: 2 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 @@ -25,9 +25,9 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 512 #1024 #2048 +ae_global_dim_embed: 512 ae_global_num_blocks: 2 -ae_global_num_heads: 32 +ae_global_num_heads: 16 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True # TODO: switching to < 1 triggers triton-related issues. @@ -37,15 +37,15 @@ ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 ae_global_trailing_layer_norm: False -ae_aggregation_num_blocks: 2 -ae_aggregation_num_heads: 32 +ae_aggregation_num_blocks: 8 +ae_aggregation_num_heads: 16 ae_aggregation_dropout_rate: 0.1 ae_aggregation_with_qk_lnorm: True ae_aggregation_att_dense_rate: 1.0 ae_aggregation_block_factor: 64 ae_aggregation_mlp_hidden_factor: 2 -decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning pred_adapter_kv: False pred_self_attention: True pred_dyadic_dims: False @@ -63,13 +63,15 @@ fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 +with_step_conditioning: True # False -healpix_level: 5 +healpix_level: 4 with_mixed_precision: True with_flash_attention: True compile_model: False -with_fsdp: True +with_fsdp: False +ddp_find_unused_parameters: False attention_dtype: bf16 mixed_precision_dtype: bf16 mlp_norm_eps: 1e-5 @@ -82,18 +84,17 @@ latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True freeze_modules: "" -load_chkpt: {} norm_type: "LayerNorm" +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore ##################################### streams_directory: "./config/streams/era5_1deg/" +# streams_directory: "./config/streams/era5_nppatms_synop/" streams: ??? 
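For reference, the `healpix_level: 5 -> 4` change above shrinks the latent grid that the aggregation and global blocks operate on. A quick illustrative sketch (not repo code) of the cell count per level, assuming the standard HEALPix relation npix = 12 * nside^2 with nside = 2^level:

```python
def num_healpix_cells(level: int) -> int:
    """Number of HEALPix cells at a given refinement level (nside = 2**level)."""
    nside = 2 ** level
    return 12 * nside ** 2

print(num_healpix_cells(4))  # 3072 cells  (new default in this diff)
print(num_healpix_cells(5))  # 12288 cells (previous default)
```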
-# type of zarr_store -zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore - general: # mutable parameters @@ -107,7 +108,7 @@ general: # model_path, # run_path, # path_shared_ - + multiprocessing_method: "fork" desc: "" @@ -125,36 +126,55 @@ data_loading : num_workers: 12 rng_seed: ??? - repeat_data_in_mini_epoch : False - - # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with - # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. - # If this happens, you can disable the flag, but performance will drop on GH200. - memory_pinning: True # config for training training_config: - + # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["masking"] + training_mode: ["masking", "student_teacher"] + + # Collapse monitoring for SSL training (JEPA/DINO/iBOT) + # Detects representation collapse via various metrics + collapse_monitoring: + enabled: true + compute_frequency: 100 # batches between metric computations + log_frequency: 100 # batches between metric logging + metrics: + effective_rank: + enabled: true + tensor_source: "both" # "student", "teacher", or "both" + sample_size: 2048 # max samples for SVD (0 = no sampling) + singular_values: + enabled: true + tensor_source: "both" + sample_size: 2048 + dimension_variance: + enabled: true + tensor_source: "both" # cheap to compute, good early indicator + prototype_entropy: + enabled: true # only applies to DINO + ema_beta: + enabled: true num_mini_epochs: 32 samples_per_mini_epoch: 4096 shuffle: True start_date: 1979-01-01T00:00 - end_date: 2022-12-31T00:00 + end_date: 2021-12-31T00:00 time_window_step: 06:00:00 time_window_len: 06:00:00 + + window_offset_prediction : 0 learning_rate_scheduling : lr_start: 1e-6 - lr_max: 5e-5 + lr_max: 1e-5 lr_final_decay: 1e-6 lr_final: 0.0 - num_steps_warmup: 512 + num_steps_warmup: 4096 num_steps_cooldown: 512 policy_warmup: "cosine" policy_decay: "constant" @@ -162,34 +182,93 @@ training_config: parallel_scaling_policy: "sqrt" optimizer: - grad_clip: 1.0 - weight_decay: 0.1 + # Optimizer type: "adamw" (default) or "muon_adamw" (Muon for hidden weights, AdamW for embeddings/heads) + type: "muon_adamw" + grad_clip: 0.5 + weight_decay: 0.05 log_grad_norms: False adamw : # parameters are scaled by number of DDP workers beta1 : 0.975 beta2 : 0.9875 eps : 2e-08 + muon: + # Learning rate multiplier for Muon relative to base LR (muon_lr = base_lr * lr_multiplier) + lr_multiplier: 30.0 + # Momentum factor for Muon SGD + momentum: 0.95 + # Use Nesterov momentum + nesterov: true + # Weight decay for Muon parameters (uses optimizer.weight_decay if not specified) + weight_decay: 0.05 losses : { - "physical": { - type: LossPhysical, - loss_fcts: { "mse": { }, }, + # "physical": { + # enabled: False, + # type: LossPhysical, + # weight: 0.1, + # loss_fcts: { + # "mse": { + # weight: 1.0, + # target_source_correspondence: { 0 : { 0 : "subset"} }, + # }, + # }, + # target_and_aux_calc: "Physical", + # }, + "student-teacher": { + enabled: True, + type: LossLatentSSLStudentTeacher, + weight: 1.0, + loss_fcts : { + "JEPA": { + 'weight': 4, "loss_extra_args": {}, "out_dim": 512, "head": transformer, + "num_blocks": 12, "num_heads": 16, "with_qk_lnorm": True, "intermediate_dim": 256, + "dropout_rate": 0.1, + target_source_correspondence: {0 : {0 : "subset"} }, + }, }, - } + # target_and_aux_calc: {FrozenTeacher: { + # teacher_run_id: "yoqxf234", # "zosrc8ti", # Required + # teacher_mini_epoch: -1}}, + # }, + target_and_aux_calc: { 
"EMATeacher" : + { ema_ramp_up_ratio : null, + ema_halflife_in_thousands: 1e-0, + model_param_overrides : { + training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} + }, + } + } + } + } model_input: { - "forecasting" : { - # masking strategy: "random", "healpix", "forecast" - masking_strategy: "forecast", + "random_easy" : { + # masking strategy: "random", "forecast" + masking_strategy: "healpix", + num_samples: 1, + num_steps_input: 1, + masking_strategy_config : { + diffusion_rn : False, + rate : 0.6, + hl_mask: 2, + rate_sampling: True }, - } + }, + } + + target_input: { + "random_easy_target" : { + masking_strategy: "healpix", + num_samples: 1, + masking_strategy_config : { rate : 0.66, hl_mask: 3, rate_sampling: True}, + }, + } forecast : - time_step: 06:00:00 - num_steps: 2 - offset: 1 - policy: "fixed" + time_step: 00:00:00 + num_steps: 0 + policy: null # validation config; full validation config is merge of training and validation config @@ -198,12 +277,12 @@ validation_config: samples_per_mini_epoch: 256 shuffle: False - start_date: 2023-10-01T00:00 - end_date: 2023-12-31T00:00 + start_date: 2022-10-01T00:00 + end_date: 2022-12-31T00:00 # whether to track the exponential moving average of weights for validation validate_with_ema: - enabled : True + enabled : False ema_ramp_up_ratio: 0.09 ema_halflife_in_thousands: 1e-3 @@ -215,12 +294,11 @@ validation_config: normalized_samples: False, # output streams to write; default all streams: null, - } + } # run validation before training starts (mainly for model development) - validate_before_training: False - - + validate_before_training: 8 + # test config; full test config is merge of validation and test config # test config is used by default when running inference @@ -246,7 +324,7 @@ wgtags: # issue number. # Expected values are lowercase strings with no spaces, just underscores: # Examples: "rollout_ablation_grid" - exp: null + exp: jepa # *** Experiment-specific tags *** # All extra tags (including lists, dictionaries, etc.) are treated # as strings by mlflow, so treat all extra tags as simple string key: value pairs. diff --git a/config/eval_nppatms.yml b/config/eval_nppatms.yml new file mode 100644 index 000000000..75e279688 --- /dev/null +++ b/config/eval_nppatms.yml @@ -0,0 +1,67 @@ +#optional: if commented out all is taken care of by the default settings +# NB. global options apply to all run_ids +#global_plotting_options: +# region: ["belgium", "global"] +# image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. +# dpi_val : 300 +# fps: 2 +# ERA5: +# marker_size: 2 +# scale_marker_size: 1 +# marker: "o" +# # alpha: 0.5 +# 2t: +# vmin: 250 +# vmax: 300 +# 10u: +# vmin: -40 +# vmax: 40 + +evaluation: + metrics : ["rmse", "mae"] + regions: ["global"] + summary_plots : true + ratio_plots : false + heat_maps : false + summary_dir: "./plots/" + plot_ensemble: "members" #supported: false, "std", "minmax", "members" + plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + print_summary: false #print out score values on screen. 
it can be verbose + log_scale: false + add_grid: false + score_cards: false + bar_plots: false + num_processes: 0 #options: int, "auto", 0 means no parallelism (default) + # baseline: "ar40mckx" + + +default_streams: + NPPATMS: + channels: ["obsvalue_rawbt_1", "obsvalue_rawbt_2", "obsvalue_rawbt_3", "obsvalue_rawbt_4", "obsvalue_rawbt_5", "obsvalue_rawbt_6", "obsvalue_rawbt_10", "obsvalue_rawbt_20"] #, "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + evaluation: + forecast_step: "all" + sample: [1, 2, 3, 4, 5, 6, 7] + ensemble: "all" #supported: "all", "mean", [0,1,2] + plotting: + sample: [1, 2, 3, 4, 5, 6, 7] + forecast_step: "all" #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) + ensemble: "all" #supported: "all", "mean", [0,1,2] + plot_maps: true + plot_target: true + plot_histograms: true + plot_animations: false + +run_ids : + # diagnostic NPPATMS model, spm6zeor, continued training of pretrained model k8smwg67 + v9ntzbh2: + label: "diagnostic NPPATMS model spm6zeor, cont. k8smwg67" + results_base_dir : "./results/v9ntzbh2/" + # prognostic NPPATMS model + #i74cu321: + # label: "pretrained model i74cu321" + # results_base_dir : "./results/i74cu321/" + # here below we have one without --options test_config.output.normalized_samples=False + # us3wofcj: + # label: "pretrained model us3wofcj" + # results_base_dir : "./results/us3wofcj/" + #NEW: if "streams" is not specified, the default streams are used diff --git a/config/streams/era5_nppatms_finetuning/era5.yml b/config/streams/era5_nppatms_finetuning/era5.yml new file mode 100644 index 000000000..45d7ddf9c --- /dev/null +++ b/config/streams/era5_nppatms_finetuning/era5.yml @@ -0,0 +1,40 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + loss_weight : 1. 
+ location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + forcing: True + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 + diff --git a/config/streams/era5_nppatms_finetuning/nppatms.yml b/config/streams/era5_nppatms_finetuning/nppatms.yml new file mode 100644 index 000000000..a1ff0552b --- /dev/null +++ b/config/streams/era5_nppatms_finetuning/nppatms.yml @@ -0,0 +1,28 @@ +# obs_types +# 0 : polar orbiting satellites +# 1 : geostationay satellites +# 2 : conventional observations + +NPPATMS : + type : obs + stream_id : 1 + filenames : ['observations-ea-ofb-0001-2012-2023-npp-atms-radiances-v2.zarr'] + loss_weight : 1.0 + token_size : 32 + diagnostic: True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 2 + dim_embed : 128 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 128 + target_readout : + num_layers : 1 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/integration_tests/jepa1_test.py b/integration_tests/jepa1_test.py index f2959f3c9..ce87a877d 100644 --- a/integration_tests/jepa1_test.py +++ b/integration_tests/jepa1_test.py @@ -12,12 +12,11 @@ import os import shutil from pathlib import Path -import omegaconf -import pytest + import numpy as np +import pytest -from weathergen.evaluate.run_evaluation import evaluate_from_config -from weathergen.run_train import inference_from_args, train_with_args +from weathergen.run_train import main from weathergen.utils.metrics import get_train_metrics_path logger = logging.getLogger(__name__) @@ -48,14 +47,14 @@ def setup(test_run_id): @pytest.mark.parametrize("test_run_id", ["test_jepa1_" + commit_hash]) def test_train(setup, test_run_id): logger.info(f"test_train with run_id {test_run_id} {WEATHERGEN_HOME}") - - train_with_args( - [ f"--config={WEATHERGEN_HOME}/integration_tests/jepa1.yaml" ] - + [ + + main( + [ + "train", + f"--config={WEATHERGEN_HOME}/integration_tests/jepa1.yaml", "--run-id", test_run_id, - ], - f"{WEATHERGEN_HOME}/config/streams/streams_test/", + ] ) assert_missing_metrics_file(test_run_id) @@ -85,12 +84,26 @@ def assert_missing_metrics_file(run_id): def assert_nans_in_metrics_file(run_id): """Test that there are no NaNs in the metrics file.""" metrics = load_metrics(run_id) - loss_values_train = np.array([entry.get('LossLatentSSLStudentTeacher.loss_avg') for entry in metrics if entry.get("stage") == 'train']) - loss_values_val = np.array([entry.get('LossLatentSSLStudentTeacher.loss_avg') for entry in metrics if entry.get("stage") == 'val']) + loss_values_train = np.array( + [ + entry.get('LossLatentSSLStudentTeacher.loss_avg') + for entry in metrics if entry.get("stage") == 'train' + ] + ) + loss_values_val = np.array( + [ + entry.get('LossLatentSSLStudentTeacher.loss_avg') + for entry in metrics if entry.get("stage") == 'val' + ] + ) #remove nans if applicable - loss_values_train = np.array([float(value) if value != 'nan' else np.nan for value in loss_values_train]) - loss_values_val = np.array([float(value) if value != 'nan' else np.nan for value in loss_values_val]) + loss_values_train = np.array( + [float(value) if value != 'nan' else np.nan for value in loss_values_train] + ) + loss_values_val = np.array( + 
[float(value) if value != 'nan' else np.nan for value in loss_values_val] + ) assert not np.isnan(loss_values_train).any(), ( "NaN values found in training loss metrics!" diff --git a/integration_tests/small1_test.py b/integration_tests/small1_test.py index d3c6e4024..b4845d157 100644 --- a/integration_tests/small1_test.py +++ b/integration_tests/small1_test.py @@ -15,9 +15,9 @@ import omegaconf import pytest -from weathergen.evaluate.run_evaluation import evaluate_from_config -from weathergen.run_train import inference_from_args, train_with_args +from weathergen.evaluate.run_evaluation import evaluate_from_config +from weathergen.run_train import main from weathergen.utils.metrics import get_train_metrics_path logger = logging.getLogger(__name__) @@ -49,13 +49,13 @@ def setup(test_run_id): def test_train(setup, test_run_id): logger.info(f"test_train with run_id {test_run_id} {WEATHERGEN_HOME}") - train_with_args( - f"--config={WEATHERGEN_HOME}/integration_tests/small1.yaml".split() - + [ + main( + [ + "inference", + f"--config={WEATHERGEN_HOME}/integration_tests/small1.yaml", "--run-id", test_run_id, - ], - f"{WEATHERGEN_HOME}/config/streams/streams_test/", + ] ) infer_with_missing(test_run_id) @@ -68,9 +68,16 @@ def test_train(setup, test_run_id): def infer(run_id): logger.info("run inference") - inference_from_args( - ["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--mini-epoch", "0"] - + [ + main( + [ + "-start", + "2022-10-10", + "-end", + "2022-10-11", + "--samples", + "10", + "--mini-epoch", + "0", "--from-run-id", run_id, "--run-id", @@ -83,9 +90,16 @@ def infer(run_id): def infer_with_missing(run_id): logger.info("run inference") - inference_from_args( - ["-start", "2021-10-10", "-end", "2022-10-11", "--samples", "10", "--mini-epoch", "0"] - + [ + main( + [ + "-start", + "2021-10-10", + "-end", + "2022-10-11", + "--samples", + "10", + "--mini-epoch", + "0", "--from-run-id", run_id, "--run-id", diff --git a/integration_tests/small_multi_stream.yaml b/integration_tests/small_multi_stream.yaml index a3edcc69a..35a1e2767 100644 --- a/integration_tests/small_multi_stream.yaml +++ b/integration_tests/small_multi_stream.yaml @@ -186,6 +186,7 @@ training_config: time_step: 06:00:00 num_steps: 2 policy: "fixed" + offset: 1 # validation config; full validation config is merge of training and validation config diff --git a/integration_tests/small_multi_stream_test.py b/integration_tests/small_multi_stream_test.py index 36cf12f5e..211581e78 100644 --- a/integration_tests/small_multi_stream_test.py +++ b/integration_tests/small_multi_stream_test.py @@ -23,9 +23,9 @@ import omegaconf import pytest -from weathergen.evaluate.run_evaluation import evaluate_from_config -from weathergen.run_train import inference_from_args, train_with_args +from weathergen.evaluate.run_evaluation import evaluate_from_config +from weathergen.run_train import main from weathergen.utils.metrics import get_train_metrics_path logger = logging.getLogger(__name__) @@ -58,15 +58,15 @@ def test_train_multi_stream(setup, test_run_id): """Test training with multiple streams including gridded and observation data.""" logger.info(f"test_train_multi_stream with run_id {test_run_id} {WEATHERGEN_HOME}") - train_with_args( - f"--base-config={WEATHERGEN_HOME}/integration_tests/small_multi_stream.yaml".split() - + [ + main( + [ + "train", + f"--base-config={WEATHERGEN_HOME}/integration_tests/small_multi_stream.yaml", "--run-id", test_run_id, - ], - f"{WEATHERGEN_HOME}/integration_tests/streams_multi/", + ] ) - + 
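The tests above now go through the single `main` entry point, with the stage ("train", "train_continue", "inference") as the first positional argument, instead of the old `train_with_args` / `inference_from_args` helpers. A minimal usage sketch; the config path and run id are placeholders:

```python
from weathergen.run_train import main

# Training run (stage first, remaining flags unchanged from the old helpers).
main(["train", "--config=integration_tests/small1.yaml", "--run-id", "my_run"])

# Inference against an existing run.
main([
    "inference",
    "-start", "2022-10-10",
    "-end", "2022-10-11",
    "--samples", "10",
    "--mini-epoch", "0",
    "--from-run-id", "my_run",
    "--run-id", "my_run",
])
```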
infer_multi_stream(test_run_id) # evaluate_multi_stream_results(test_run_id) assert_metrics_file_exists(test_run_id) @@ -78,9 +78,17 @@ def test_train_multi_stream(setup, test_run_id): def infer_multi_stream(run_id): """Run inference for multi-stream model.""" logger.info("run multi-stream inference") - inference_from_args( - ["-start", "2021-10-10", "-end", "2022-10-11", "--samples", "10", "--mini-epoch", "0"] - + [ + main( + [ + "inference", + "-start", + "2021-10-10", + "-end", + "2022-10-11", + "--samples", + "10", + "--mini-epoch", + "0", "--from-run-id", run_id, "--run-id", diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 794f829c6..f0243a717 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -301,6 +301,26 @@ def _apply_fixes(config: Config) -> Config: eventually removed. """ config = _check_logging(config) + config = _check_datasets(config) + return config + + +def _check_datasets(config: Config) -> Config: + """ + Collect dataset paths under legacy keys. + """ + config = config.copy() + if config.get("data_paths") is None: # TODO remove this for next version + legacy_keys = [ + "data_path_anemoi", + "data_path_obs", + "data_path_eobs", + "data_path_fesom", + "data_path_icon", + ] + paths = [config.get(key) for key in legacy_keys] + config.data_paths = [path for path in paths if path is not None] + return config @@ -526,6 +546,8 @@ def _load_private_conf(private_home: Path | None = None) -> DictConfig: if "secrets" in private_cf: del private_cf["secrets"] + private_cf = _check_datasets(private_cf) # TODO: remove temp backward compatibility fix + assert isinstance(private_cf, DictConfig) return private_cf diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 44ec60a3b..e3cdfefa6 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -607,7 +607,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]: image_paths += names if image_paths: - image_paths=sorted(image_paths) + image_paths = sorted(image_paths) images = [Image.open(path) for path in image_paths] images[0].save( f"{map_output_dir}/animation_{self.run_id}_{tag}_{sa}_{self.stream}_{region}_{var}.gif", diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index ab322ab28..c5d4cf8d0 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -190,6 +190,7 @@ def __init__( "grad_amplitude": self.calc_spatial_variability, "psnr": self.calc_psnr, "seeps": self.calc_seeps, + "nse": self.calc_nse, } self.prob_metrics_dict = { "ssr": self.calc_ssr, @@ -1199,6 +1200,34 @@ def seeps(ground_truth, prediction, thr_light, thr_heavy, seeps_weights): return seeps_values + def calc_nse(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: + """ + Calculate Nash–Sutcliffe_model_efficiency_coefficient (NSE) + of forecast data vs reference data + Metrics broadly used in hydrology + Parameters + ---------- + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + Returns + ------- + xr.DataArray + Nash–Sutcliffe_model_efficiency_coefficient (NSE) + + """ + + obs_mean = gt.mean(dim=self._agg_dims) + + num = ((gt - p) ** 
2).sum(dim=self._agg_dims) + + den = ((gt - obs_mean) ** 2).sum(dim=self._agg_dims) + + nse = 1 - num / den + + return nse + ### Probablistic scores def calc_spread(self, p: xr.DataArray, **kwargs) -> xr.DataArray: diff --git a/packages/readers_extra/src/weathergen/readers_extra/registry.py b/packages/readers_extra/src/weathergen/readers_extra/registry.py index 957a5a350..39953a25e 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/registry.py +++ b/packages/readers_extra/src/weathergen/readers_extra/registry.py @@ -1,36 +1,24 @@ -from collections.abc import Callable -from dataclasses import dataclass - -from weathergen.common.config import Config - - -@dataclass -class ReaderEntry: - data_path: str | None - constructor: Callable - - -def get_extra_reader(name: str, cf: Config) -> object | None: - """Get an extra reader by name.""" +def get_extra_reader(stream_type: str) -> object | None: + """Get an extra reader by stream_type name.""" # Uses lazy imports to avoid circular dependencies and to not load all the readers at start. # There is no sanity check on them, so they may fail at runtime during imports - match name: + match stream_type: case "iconart": from weathergen.readers_extra.data_reader_iconart import DataReaderIconArt - return ReaderEntry(cf.data_path_icon, DataReaderIconArt) + return DataReaderIconArt case "eobs": from weathergen.readers_extra.data_reader_eobs import DataReaderEObs - return ReaderEntry(cf.data_path_eobs, DataReaderEObs) + return DataReaderEObs case "iconesm": from weathergen.readers_extra.data_reader_icon_esm import DataReaderIconEsm - return ReaderEntry(cf.data_path_icon_esm, DataReaderIconEsm) + return DataReaderIconEsm case "cams": from weathergen.readers_extra.data_reader_cams import DataReaderCams - return ReaderEntry(cf.data_path_cams, DataReaderCams) + return DataReaderCams case _: return None diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index aa3a61f44..f84111541 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -9,7 +9,7 @@ from numpy.typing import NDArray from weathergen.datasets.batch import SampleMetaData -from weathergen.utils.train_logger import Stage +from weathergen.train.utils import Stage from weathergen.utils.utils import is_stream_diagnostic, is_stream_forcing logger = logging.getLogger(__name__) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 83d436fb9..5afc72d69 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -31,9 +31,8 @@ get_tokens_lens, ) from weathergen.readers_extra.registry import get_extra_reader -from weathergen.train.utils import get_batch_size_from_config +from weathergen.train.utils import TRAIN, Stage, get_batch_size_from_config from weathergen.utils.distributed import is_root -from weathergen.utils.train_logger import TRAIN, Stage type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs type StreamName = str @@ -151,40 +150,36 @@ def __init__( match stream_info["type"]: case "obs": dataset = DataReaderObs - datapath = cf.data_path_obs - # kwargs["end"] = end_date_padded # TODO: implement the padding case "anemoi": dataset = DataReaderAnemoi - datapath = cf.data_path_anemoi case "fesom": dataset = DataReaderFesom - datapath = cf.data_path_fesom case type_name: - reader_entry = get_extra_reader(type_name, cf) - if reader_entry is not None: - dataset = 
reader_entry.constructor - datapath = reader_entry.data_path - else: + dataset = get_extra_reader(type_name) + if dataset is None: msg = f"Unsupported stream type {stream_info['type']}" f"for stream name '{stream_info['name']}'." raise ValueError(msg) - datapath = pathlib.Path(datapath) fname = pathlib.Path(fname) # dont check if file exists since zarr stores might be directories if fname.exists(): # check if fname is a valid path to allow for simple overwriting filename = fname else: - filename = pathlib.Path(datapath) / fname + filenames = [pathlib.Path(path) / fname for path in cf.data_paths] - if not filename.exists(): # see above + if not any(filename.exists() for filename in filenames): # see above msg = ( f"Did not find input data for {stream_info['type']} " - f"stream '{stream_info['name']}': {filename}." + f"stream '{stream_info['name']}': {filenames}." ) raise FileNotFoundError(msg) + # The same dataset can exist on different locations in the filesystem, + # so we need to choose here. + filename = filenames[0] + ds_type = stream_info["type"] if is_root(): logger.info( @@ -533,12 +528,12 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s # source data: iterate overall input steps input_data = [] - for idx in range(base_idx - num_steps_input_max, base_idx + 1): + for idx in range(base_idx - num_steps_input_max + 1, base_idx + 1): # TODO: check that we are not out of bounds when we go back in time rdata = collect_datasources(stream_ds, idx, "source", self.rng) - if rdata.is_empty() and self._stage == TRAIN: + if rdata.is_empty(): # and self._stage == TRAIN: # work around for https://github.com/pytorch/pytorch/issues/158719 # create non-empty mean data instead of empty tensor time_win = self.time_window_handler.window(idx) @@ -560,7 +555,7 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s rdata = collect_datasources(stream_ds, step_forecast_dt, "target", self.rng) - if rdata.is_empty() and self._stage == TRAIN: + if rdata.is_empty(): # and self._stage == TRAIN: # work around for https://github.com/pytorch/pytorch/issues/158719 # create non-empty mean data instead of empty tensor time_win = self.time_window_handler.window(timestep_idx) @@ -631,7 +626,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): if "masking" in mode: source_select += ["network_input", "target_coords"] target_select += ["target_values"] - if "student_teacher" in mode or mode == "latent_loss" in mode: + if "student_teacher" in mode or "latent_loss" in mode: source_select += ["network_input"] target_select += ["network_input"] # remove duplicates diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 3d61767f0..6dfe71c89 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -66,6 +66,10 @@ def get_tokens_windows(self, stream_info, data, pad_tokens): tokens = [] for rdata in data: + # skip empty data + if rdata.is_empty(): + continue + # tokenize data idxs_cells, idxs_cells_lens = tok( readerdata_to_torch(rdata), token_size, hl, pad_tokens ) diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 99606bdce..0743a1af0 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -13,6 +13,7 @@ from flash_attn import flash_attn_func, flash_attn_varlen_func from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from weathergen.model.layers import 
LayerScale, StochasticDepth from weathergen.model.norms import AdaLayerNorm, RMSNorm @@ -31,6 +32,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiSelfAttentionHeadVarlen, self).__init__() @@ -66,6 +69,16 @@ def __init__( self.dtype = attention_dtype + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + assert with_flash, "Only flash attention supported at the moment" def forward(self, x, x_lens, ada_ln_aux=None): @@ -99,6 +112,14 @@ def forward(self, x, x_lens, ada_ln_aux=None): out = self.proj_out(outs.flatten(-2, -1)) + # Apply LayerScale before residual + if self.layer_scale is not None: + out = self.layer_scale(out) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + out = self.drop_path(out) + if self.with_residual: out = out + x_in @@ -119,6 +140,8 @@ def __init__( softcap=0.0, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiSelfAttentionHeadVarlenFlex, self).__init__() @@ -149,6 +172,16 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + assert with_flash, "Only flash attention supported at the moment" def att(qs, ks, vs, x_mask): @@ -174,6 +207,15 @@ def forward(self, x, x_lens=None): outs = self.compiled_flex_attention(qs, ks, vs).transpose(1, 2).squeeze() out = self.dropout(self.proj_out(outs.flatten(-2, -1))) + + # Apply LayerScale before residual + if self.layer_scale is not None: + out = self.layer_scale(out) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + out = self.drop_path(out) + if self.with_residual: out = out + x_in @@ -197,6 +239,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -230,6 +274,17 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + assert with_flash, "Only flash attention supported." 
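Every attention variant in this file gets the same LayerScale / StochasticDepth wiring. A condensed, illustrative sketch of the shared pattern (simplified from the actual forward passes, using a plain linear layer as a stand-in for attention):

```python
import torch

from weathergen.model.layers import LayerScale, StochasticDepth


class ResidualBlockSketch(torch.nn.Module):
    """Ordering used throughout attention.py: transform, scale per channel,
    optionally drop the whole residual branch, then add the skip connection."""

    def __init__(self, dim, layer_scale_init=None, stochastic_depth_rate=0.0):
        super().__init__()
        self.body = torch.nn.Linear(dim, dim)  # stand-in for attention / MLP
        self.layer_scale = (
            LayerScale(dim, layer_scale_init) if layer_scale_init is not None else None
        )
        self.drop_path = (
            StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None
        )

    def forward(self, x):
        out = self.body(x)
        if self.layer_scale is not None:  # per-channel learned scaling (CaiT)
            out = self.layer_scale(out)
        if self.drop_path is not None:  # randomly drop the residual branch in training
            out = self.drop_path(out)
        return x + out  # residual connection
```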
# define block mask @@ -256,6 +311,15 @@ def forward(self, x, ada_ln_aux=None): outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) out = self.proj_out(self.dropout(outs.flatten(-2, -1))) + + # Apply LayerScale before residual + if self.layer_scale is not None: + out = self.layer_scale(out) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + out = self.drop_path(out) + if self.with_residual: out = x_in + out @@ -278,6 +342,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiCrossAttentionHeadVarlen, self).__init__() @@ -318,6 +384,17 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed_q, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + assert with_flash, "Only flash attention supported at the moment" def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): @@ -355,6 +432,15 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): assert False outs = self.proj_out(outs.flatten(-2, -1)) + + # Apply LayerScale before residual + if self.layer_scale is not None: + outs = self.layer_scale(outs) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + outs = self.drop_path(outs) + if self.with_residual: outs = x_q_in + outs @@ -378,6 +464,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiCrossAttentionHeadVarlenSlicedQ, self).__init__() @@ -425,6 +513,17 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed_q, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + assert with_flash, "Only flash attention supported at the moment" def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): @@ -466,6 +565,15 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): ] outs = self.proj_out(torch.stack(outs).transpose(1, 0).flatten(-2, -1)) + + # Apply LayerScale before residual + if self.layer_scale is not None: + outs = self.layer_scale(outs) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + outs = self.drop_path(outs) + if self.with_residual: outs = x_q_in + outs.reshape(x_q_in.shape) @@ -487,6 +595,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiSelfAttentionHead, self).__init__() @@ -521,6 +631,17 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed, layer_scale_init) if layer_scale_init is not None else None + 
) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + if with_flash: self.att = torch.nn.functional.scaled_dot_product_attention else: @@ -546,6 +667,15 @@ def forward(self, x, ada_ln_aux=None): outs = flash_attn_func(qs, ks, vs, softcap=self.softcap, dropout_p=dropout_rate) out = self.proj_out(outs.flatten(-2, -1)) + + # Apply LayerScale before residual + if self.layer_scale is not None: + out = self.layer_scale(out) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + out = self.drop_path(out) + if self.with_residual: out = out + x_in @@ -566,6 +696,8 @@ def __init__( norm_type="LayerNorm", norm_eps=1e-5, attention_dtype=torch.bfloat16, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): super(MultiCrossAttentionHead, self).__init__() @@ -602,6 +734,17 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype + + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_embed_q, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + self.att = torch.nn.functional.scaled_dot_product_attention self.softmax = torch.nn.Softmax(dim=-1) @@ -624,6 +767,15 @@ def forward(self, x_q, x_kv): outs = self.att(qs, ks, vs).transpose(2, 1) outs = self.dropout(self.proj_out(outs.flatten(-2, -1))) + + # Apply LayerScale before residual + if self.layer_scale is not None: + outs = self.layer_scale(outs) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + outs = self.drop_path(outs) + if self.with_residual: outs = x_q_in + outs diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py index 141947863..b42d756d4 100644 --- a/src/weathergen/model/ema.py +++ b/src/weathergen/model/ema.py @@ -30,6 +30,7 @@ def __init__( self.rampup_ratio = rampup_ratio self.ema_model = empty_model self.is_model_sharded = is_model_sharded + self.batch_size = 1 # Build a name → param map once self.src_params = dict(self.original_model.named_parameters()) @@ -55,16 +56,33 @@ def requires_grad_(self, flag: bool): for p in self.ema_model.parameters(): p.requires_grad = flag + def get_current_beta(self, cur_step: int) -> float: + """ + Get current EMA beta value for monitoring. + + The beta value determines how much the teacher model is updated towards + the student model at each step. Higher beta means slower teacher updates. + + Args: + cur_step: Current training step (typically istep * batch_size). + + Returns: + Current EMA beta value. 
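A numeric illustration of the beta formula in `get_current_beta` (values are illustrative only; whether `halflife_steps` is measured in raw steps or in thousands depends on how `EMAModel` is constructed from `ema_halflife_in_thousands`):

```python
def ema_beta(batch_size, halflife_steps, cur_step=None, rampup_ratio=None):
    """Mirror of EMAModel.get_current_beta: halve the teacher/student gap every halflife_steps."""
    if rampup_ratio is not None and cur_step is not None:
        halflife_steps = min(halflife_steps, cur_step / rampup_ratio)
    return 0.5 ** (batch_size / max(halflife_steps, 1e-6))

print(ema_beta(batch_size=1, halflife_steps=1000))  # ~0.99931: slow teacher updates
print(ema_beta(batch_size=8, halflife_steps=1000))  # ~0.99447: larger effective step per update
```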
+ """ + halflife_steps = self.halflife_steps + if self.rampup_ratio is not None: + halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio) + beta = 0.5 ** (self.batch_size / max(halflife_steps, 1e-6)) + return beta + @torch.no_grad() def update(self, cur_step, batch_size): # ensure model remains sharded if self.is_model_sharded: self.ema_model.reshard() # determine correct interpolation params - halflife_steps = self.halflife_steps - if self.rampup_ratio is not None: - halflife_steps = min(halflife_steps, cur_step / 1e3 * self.rampup_ratio) - beta = 0.5 ** (batch_size / max(halflife_steps * 1e3, 1e-6)) + self.batch_size = batch_size + beta = self.get_current_beta(cur_step) for name, p_ema in self.ema_model.named_parameters(): p_src = self.src_params.get(name, None) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 47e059014..c51124b80 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -153,7 +153,6 @@ def assimilate_local_project_chunked(self, tokens, tokens_global, cell_lens, q_c # combined cell lens for all tokens in batch across all input steps zero_pad = torch.zeros(1, device=tokens.device, dtype=torch.int32) - # subdivision factor for required splitting clen = self.num_healpix_cells // (2 if self.cf.healpix_level <= 5 else 8) tokens_global_unmasked = [] posteriors = [] diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index de5328a93..8769b5636 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -138,7 +138,16 @@ def __init__(self, cf: Config) -> None: self.cf = cf self.ae_local_blocks = torch.nn.ModuleList() - for _ in range(self.cf.ae_local_num_blocks): + # Get LayerScale and StochasticDepth config + layer_scale_init = cf.get("layer_scale_init", None) + stochastic_depth_cfg = cf.get("stochastic_depth", {}) + max_drop_rate = stochastic_depth_cfg.get("ae_local", 0.0) if stochastic_depth_cfg else 0.0 + num_blocks = self.cf.ae_local_num_blocks + + for i in range(num_blocks): + # Linear scaling of drop rate with depth + drop_rate = max_drop_rate * (i / max(num_blocks - 1, 1)) if num_blocks > 1 else 0.0 + self.ae_local_blocks.append( MultiSelfAttentionHeadVarlen( self.cf.ae_local_dim_embed, @@ -149,6 +158,8 @@ def __init__(self, cf: Config) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) self.ae_local_blocks.append( @@ -159,6 +170,8 @@ def __init__(self, cf: Config) -> None: dropout_rate=self.cf.ae_local_dropout_rate, norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) @@ -181,6 +194,15 @@ def __init__(self, cf: Config) -> None: self.cf = cf self.ae_adapter = torch.nn.ModuleList() + # Get LayerScale and StochasticDepth config + layer_scale_init = cf.get("layer_scale_init", None) + stochastic_depth_cfg = cf.get("stochastic_depth", {}) + # Use ae_local rate for adapter (transition layer) + max_drop_rate = stochastic_depth_cfg.get("ae_local", 0.0) if stochastic_depth_cfg else 0.0 + ae_adapter_num_blocks = cf.get("ae_adapter_num_blocks", 2) + + # First block + drop_rate = 0.0 if ae_adapter_num_blocks <= 1 else 0.0 self.ae_adapter.append( MultiCrossAttentionHeadVarlenSlicedQ( self.cf.ae_global_dim_embed, @@ -195,11 +217,19 @@ def __init__(self, cf: Config) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, 
attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) - ae_adapter_num_blocks = cf.get("ae_adapter_num_blocks", 2) - for _ in range(ae_adapter_num_blocks - 1): + for i in range(ae_adapter_num_blocks - 1): + # Linear scaling of drop rate with depth + drop_rate = ( + max_drop_rate * ((i + 1) / max(ae_adapter_num_blocks - 1, 1)) + if ae_adapter_num_blocks > 1 + else 0.0 + ) + self.ae_adapter.append( MLP( self.cf.ae_global_dim_embed, @@ -208,6 +238,8 @@ def __init__(self, cf: Config) -> None: dropout_rate=self.cf.ae_adapter_dropout_rate, norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) self.ae_adapter.append( @@ -224,6 +256,8 @@ def __init__(self, cf: Config) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) @@ -257,12 +291,23 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: self.ae_aggregation_blocks = torch.nn.ModuleList() + # Get LayerScale and StochasticDepth config + layer_scale_init = cf.get("layer_scale_init", None) + stochastic_depth_cfg = cf.get("stochastic_depth", {}) + max_drop_rate = ( + stochastic_depth_cfg.get("ae_aggregation", 0.0) if stochastic_depth_cfg else 0.0 + ) + num_blocks = self.cf.ae_aggregation_num_blocks + global_rate = int(1 / self.cf.ae_aggregation_att_dense_rate) - for i in range(self.cf.ae_aggregation_num_blocks): + for i in range(num_blocks): + # Linear scaling of drop rate with depth + drop_rate = max_drop_rate * (i / max(num_blocks - 1, 1)) if num_blocks > 1 else 0.0 + ## Alternate between local and global attention # as controlled by cf.ae_dense_local_att_dense_rate # Last block is always global attention - if i % global_rate == 0 or i + 1 == self.cf.ae_aggregation_num_blocks: + if i % global_rate == 0 or i + 1 == num_blocks: self.ae_aggregation_blocks.append( MultiSelfAttentionHeadVarlen( self.cf.ae_global_dim_embed, @@ -273,6 +318,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) else: @@ -289,6 +336,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) # MLP block @@ -301,6 +350,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: hidden_factor=self.cf.ae_aggregation_mlp_hidden_factor, norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) @@ -329,12 +380,21 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: self.ae_global_blocks = torch.nn.ModuleList() + # Get LayerScale and StochasticDepth config + layer_scale_init = cf.get("layer_scale_init", None) + stochastic_depth_cfg = cf.get("stochastic_depth", {}) + max_drop_rate = stochastic_depth_cfg.get("ae_global", 0.0) if stochastic_depth_cfg else 0.0 + num_blocks = self.cf.ae_global_num_blocks + global_rate = int(1 / self.cf.ae_global_att_dense_rate) - for i in range(self.cf.ae_global_num_blocks): + for i in range(num_blocks): + # Linear scaling of drop rate with depth + drop_rate = 
max_drop_rate * (i / max(num_blocks - 1, 1)) if num_blocks > 1 else 0.0 + ## Alternate between local and global attention # as controlled by cf.ae_global_att_dense_rate # Last block is always global attention - if i % global_rate == 0 or i + 1 == self.cf.ae_global_num_blocks: + if i % global_rate == 0 or i + 1 == num_blocks: self.ae_global_blocks.append( MultiSelfAttentionHead( self.cf.ae_global_dim_embed, @@ -345,6 +405,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) else: @@ -360,6 +422,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) # MLP block @@ -372,6 +436,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: hidden_factor=self.cf.ae_global_mlp_hidden_factor, norm_type=self.cf.norm_type, norm_eps=self.cf.mlp_norm_eps, + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) if self.cf.get("ae_global_trailing_layer_norm", False): @@ -400,9 +466,20 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.num_healpix_cells = num_healpix_cells self.fe_blocks = torch.nn.ModuleList() + # Get LayerScale and StochasticDepth config + layer_scale_init = cf.get("layer_scale_init", None) + stochastic_depth_cfg = cf.get("stochastic_depth", {}) + max_drop_rate = ( + stochastic_depth_cfg.get("forecasting", 0.0) if stochastic_depth_cfg else 0.0 + ) + num_blocks = self.cf.fe_num_blocks + global_rate = int(1 / self.cf.forecast_att_dense_rate) if mode_cfg.get("forecast", {}).get("policy") is not None: - for i in range(self.cf.fe_num_blocks): + for i in range(num_blocks): + # Linear scaling of drop rate with depth + drop_rate = max_drop_rate * (i / max(num_blocks - 1, 1)) if num_blocks > 1 else 0.0 + # Alternate between global and local attention if (i % global_rate == 0) or i + 1 == self.cf.ae_global_num_blocks: self.fe_blocks.append( @@ -416,6 +493,8 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) else: @@ -432,6 +511,8 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) # Add MLP block @@ -444,6 +525,8 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = norm_type=self.cf.norm_type, dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, + layer_scale_init=layer_scale_init, + stochastic_depth_rate=drop_rate, ) ) # Optionally, add LayerNorm after i-th layer diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 1f7b8df5d..7238d4799 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -14,6 +14,51 @@ from weathergen.model.norms import AdaLayerNorm, RMSNorm +class LayerScale(nn.Module): + """Per-channel learnable scaling, as in CaiT (Touvron et al., 2021). + + Applies a learned per-channel scaling factor to the input. 
When used before + residual connections, it allows the network to gradually incorporate new + layer contributions during training. + + Args: + dim: Number of channels/features to scale. + init_value: Initial value for the scaling factors. Use 1e-5 for LayerScale + or 0.0 for ReZero initialization. + """ + + def __init__(self, dim: int, init_value: float = 1e-5): + super().__init__() + self.gamma = nn.Parameter(init_value * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.gamma + + +class StochasticDepth(nn.Module): + """Stochastic Depth / DropPath regularization (Huang et al., 2016). + + Randomly drops entire residual paths during training. This acts as a form + of regularization and enables training deeper networks. + + Args: + drop_prob: Probability of dropping the path. 0.0 means no dropping. + """ + + def __init__(self, drop_prob: float = 0.0): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.training or self.drop_prob == 0.0: + return x + keep_prob = 1.0 - self.drop_prob + # Per-sample dropout (batch dimension) + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob) + return x * mask / keep_prob # Scale to maintain expected value + + class NamedLinear(torch.nn.Module): def __init__(self, name: str | None = None, **kwargs): super(NamedLinear, self).__init__() @@ -43,8 +88,16 @@ def __init__( dim_aux=None, norm_eps=1e-5, name: str | None = None, + layer_scale_init: float | None = None, + stochastic_depth_rate: float = 0.0, ): - """Constructor""" + """Constructor + + Args: + layer_scale_init: If not None, applies LayerScale with this init value. + Use 1e-5 for LayerScale, 0.0 for ReZero. + stochastic_depth_rate: Probability of dropping this block during training. 
+ """ super(MLP, self).__init__() @@ -79,12 +132,30 @@ def __init__( self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) + # LayerScale: per-channel learned scaling before residual + self.layer_scale = ( + LayerScale(dim_out, layer_scale_init) if layer_scale_init is not None else None + ) + + # Stochastic Depth: randomly drop residual path during training + self.drop_path = ( + StochasticDepth(stochastic_depth_rate) if stochastic_depth_rate > 0.0 else None + ) + def forward(self, *args): x, x_in, aux = args[0], args[0], args[-1] for i, layer in enumerate(self.layers): x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) + # Apply LayerScale before residual + if self.layer_scale is not None: + x = self.layer_scale(x) + + # Apply Stochastic Depth before residual + if self.drop_path is not None: + x = self.drop_path(x) + if self.with_residual: if x.shape[-1] == x_in.shape[-1]: x = x_in + x diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 24962da13..47072b46d 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -21,7 +21,7 @@ ) from torch.distributed.tensor import distribute_tensor -from weathergen.common.config import Config, merge_configs +from weathergen.common.config import Config, get_path_model, merge_configs from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -29,12 +29,13 @@ MultiSelfAttentionHeadLocal, MultiSelfAttentionHeadVarlen, ) +from weathergen.common.config import get_path_model from weathergen.model.ema import EMAModel from weathergen.model.layers import MLP from weathergen.model.model import Model, ModelParams from weathergen.model.utils import apply_fct_to_blocks, freeze_weights from weathergen.train.target_and_aux_module_base import PhysicalTargetAndAux -from weathergen.train.target_and_aux_ssl_teacher import EMATeacher +from weathergen.train.target_and_aux_ssl_teacher import EMATeacher, FrozenTeacher from weathergen.utils.distributed import is_root from weathergen.utils.utils import get_dtype @@ -152,7 +153,7 @@ def init_model_and_shard( if is_root(): logger.info(f"Continuing run with id={run_id_contd} at mini_epoch {mini_epoch_contd}.") model = load_model(cf, model, device, run_id_contd, mini_epoch_contd) - elif cf.get("load_chkpt", None).get("run_id", None): + elif cf.get("load_chkpt", {}).get("run_id", None): run_id = cf.load_chkpt.run_id mini_epoch = cf.load_chkpt.get("mini_epoch", -1) if is_root(): @@ -179,7 +180,7 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): mini_epoch : The mini_epoch to load. 
Default (-1) is the latest mini_epoch """ - path_run = Path(cf.model_path) / run_id + path_run = get_path_model(run_id=run_id) mini_epoch_id = ( f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" ) @@ -327,6 +328,9 @@ def get_target_aux_calculator( batch_size = cf.get("world_size_original", cf.get("world_size")) * batch_size_per_gpu target_aux = EMATeacher(model, ema_model, batch_size, cf.training_config) + elif target_and_aux_calc == "FrozenTeacher": + target_aux = FrozenTeacher.from_pretrained(cf, dataset, device, target_and_aux_calc_params) + else: raise NotImplementedError(f"{target_and_aux_calc} is not implemented") diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 65551fccb..e91501274 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -27,21 +27,63 @@ logger = logging.getLogger(__name__) +def train() -> None: + """Entry point for calling the training code from the command line.""" + main([cli.Stage.train] + sys.argv[1:]) + + +def train_continue() -> None: + """Entry point for calling train_continue from the command line.""" + main([cli.Stage.train_continue] + sys.argv[1:]) + + def inference(): - # By default, arguments from the command line are read. - inference_from_args(sys.argv[1:]) + """Entry point for calling the inference code from the command line.""" + main([cli.Stage.inference] + sys.argv[1:]) -def inference_from_args(argl: list[str]): +def main(argl: list[str]): + try: + argl = _fix_argl(argl) + except ValueError as e: + logger.error(str(e)) + + parser = cli.get_main_parser() + args = parser.parse_args(argl) + match args.stage: + case cli.Stage.train: + run_train(args) + case cli.Stage.train_continue: + run_continue(args) + case cli.Stage.inference: + run_inference(args) + case _: + logger.error("No stage was found.") + + +def _fix_argl(argl): # TODO remove this fix after grace period + """Ensure `stage` positional argument is in arglist.""" + if argl[0] not in cli.Stage: + try: + stage = os.environ.get("WEATHERGEN_STAGE") + except KeyError as e: + msg = ( + "`stage` postional argument and environment variable 'WEATHERGEN_STAGE' missing.", + "Provide either one or the other.", + ) + raise ValueError(msg) from e + + argl = [stage] + argl + + return argl + + +def run_inference(args): """ Inference function for WeatherGenerator model. - Entry point for calling the inference code from the command line. - When running integration tests, the arguments are directly provided. + Note: Additional configuration for inference (`test_config`) is set in the function. """ - parser = cli.get_inference_parser() - args = parser.parse_args(argl) - inference_overwrite = { "test_config": dict( shuffle=False, @@ -84,24 +126,12 @@ def inference_from_args(argl: list[str]): pdb.post_mortem(tb) -#################################################################################################### -def train_continue() -> None: +def run_continue(args): """ Function to continue training for WeatherGenerator model. - Entry point for calling train_continue from the command line. - Configurations are set in the function body. - Args: - from_run_id (str): Run/model id of pretrained WeatherGenerator model to - continue training. Defaults to None. Note: All model configurations are set in the function body. 
""" - train_continue_from_args(sys.argv[1:]) - - -def train_continue_from_args(argl: list[str]): - parser = cli.get_continue_parser() - args = parser.parse_args(argl) cli_overwrite = config.from_cli_arglist(args.options) cf = config.load_merge_configs( @@ -135,26 +165,12 @@ def train_continue_from_args(argl: list[str]): pdb.post_mortem(tb) -#################################################################################################### -def train() -> None: +def run_train(args): """ Training function for WeatherGenerator model. - Entry point for calling the training code from the command line. - Configurations are set in the function body. - Args: - run_id (str, optional): Run/model id of pretrained WeatherGenerator model to - continue training. Defaults to None. Note: All model configurations are set in the function body. """ - train_with_args(sys.argv[1:], None) - - -def train_with_args(argl: list[str], stream_dir: str | None): - """ - Training function for WeatherGenerator model.""" - parser = cli.get_train_parser() - args = parser.parse_args(argl) cli_overwrite = config.from_cli_arglist(args.options) @@ -191,20 +207,4 @@ def train_with_args(argl: list[str], stream_dir: str | None): if __name__ == "__main__": - try: - stage = os.environ.get("WEATHERGEN_STAGE") - except KeyError as e: - msg = "missing environment variable 'WEATHERGEN_STAGE'" - raise ValueError(msg) from e - - if stage == "train": - # Entry point for slurm script. - # Check whether --from-run-id passed as argument. - if any("--from-run-id" in arg for arg in sys.argv): - train_continue() - else: - train() - elif stage == "inference": - inference() - else: - logger.error("No stage was found.") + main(sys.argv[1:]) diff --git a/src/weathergen/train/collapse_monitor.py b/src/weathergen/train/collapse_monitor.py new file mode 100644 index 000000000..d77060ad8 --- /dev/null +++ b/src/weathergen/train/collapse_monitor.py @@ -0,0 +1,381 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +Collapse monitoring metrics for SSL training (JEPA/DINO). + +This module implements metrics to detect representation collapse during self-supervised learning: +- Effective Rank (RankMe): Entropy of normalized singular values +- Singular Value Spectrum: Top-k singular values and concentration ratio +- Per-Dimension Variance: Min/mean/max variance across embedding dimensions +- Prototype Entropy: Normalized entropy of DINO prototype assignments +- EMA Beta: Current teacher momentum value + +References: +- RankMe (ICML 2023): https://arxiv.org/abs/2210.02885 +- C-JEPA (NeurIPS 2024): https://arxiv.org/abs/2410.19560 +""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from typing import Any + +import torch + +logger = logging.getLogger(__name__) + + +class CollapseMonitor: + """ + Monitor for detecting representation collapse during SSL training. + + Computes and caches various collapse indicators that can be logged + at configurable intervals to minimize computational overhead. + """ + + def __init__(self, config: dict[str, Any], device: torch.device) -> None: + """ + Initialize the collapse monitor. 
+ + Args: + config: Configuration dictionary with collapse_monitoring settings. + device: Device to use for computations. + """ + self.device = device + self.enabled = config.get("enabled", False) + self.compute_frequency = config.get("compute_frequency", 100) + self.log_frequency = config.get("log_frequency", 100) + + # Metric configurations + metrics_config = config.get("metrics", {}) + + self.effective_rank_config = metrics_config.get("effective_rank", {}) + self.singular_values_config = metrics_config.get("singular_values", {}) + self.dimension_variance_config = metrics_config.get("dimension_variance", {}) + self.prototype_entropy_config = metrics_config.get("prototype_entropy", {}) + self.ema_beta_config = metrics_config.get("ema_beta", {}) + + # Cache for accumulating metrics between log intervals + self._metrics_cache: dict[str, list[float]] = defaultdict(list) + + def should_compute(self, step: int) -> bool: + """Check if metrics should be computed at this step.""" + return self.enabled and step % self.compute_frequency == 0 + + def should_log(self, step: int) -> bool: + """Check if metrics should be logged at this step.""" + return self.enabled and step % self.log_frequency == 0 + + def compute_metrics( + self, + student_latent: torch.Tensor | None = None, + teacher_latent: torch.Tensor | None = None, + prototype_probs: torch.Tensor | None = None, + ema_beta: float | None = None, + loss_type: str | None = None, + ) -> dict[str, float]: + """ + Compute all enabled collapse monitoring metrics. + + Args: + student_latent: Student model latent representations [B, N, D] or [B, D]. + teacher_latent: Teacher model latent representations [B, N, D] or [B, D]. + prototype_probs: Post-softmax prototype assignment probabilities [B, K] (DINO only). + ema_beta: Current EMA momentum value. + loss_type: Type of SSL loss ("JEPA" or "DINO"). + + Returns: + Dictionary of computed metrics. 
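A rough, self-contained sketch of how this monitor is intended to be driven; the config dict and the toy latents below are illustrative, and in the real trainer the latents would come from the student/teacher forward passes:

import torch
from weathergen.train.collapse_monitor import CollapseMonitor  # module added in this diff

collapse_cfg = {
    "enabled": True,
    "compute_frequency": 100,
    "log_frequency": 100,
    "metrics": {"effective_rank": {"enabled": True, "tensor_source": "both", "sample_size": 2048}},
}
monitor = CollapseMonitor(collapse_cfg, device=torch.device("cpu"))

student = torch.randn(4, 32, 128)  # toy [B, N, D] latents
teacher = torch.randn(4, 32, 128)

for step in range(201):  # stand-in for the batch loop
    if monitor.should_compute(step):
        monitor.compute_metrics(
            student_latent=student,
            teacher_latent=teacher,
            ema_beta=0.996,  # illustrative EMA momentum value
            loss_type="JEPA",
        )
    if monitor.should_log(step):
        print(step, monitor.get_cached_metrics())  # averaged since the last log

get_cached_metrics() clears its buffer, so each log reports the mean of the metrics computed since the previous log call.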
+ """ + if not self.enabled: + return {} + + metrics: dict[str, float] = {} + + # Determine which tensors to monitor based on config + tensors_to_monitor: dict[str, torch.Tensor | None] = {} + + effective_rank_source = self.effective_rank_config.get("tensor_source", "both") + sv_source = self.singular_values_config.get("tensor_source", "both") + var_source = self.dimension_variance_config.get("tensor_source", "both") + + # Build tensor dict based on what's requested + if effective_rank_source in ("student", "both") or sv_source in ( + "student", + "both", + ) or var_source in ("student", "both"): + tensors_to_monitor["student"] = student_latent + + if effective_rank_source in ("teacher", "both") or sv_source in ( + "teacher", + "both", + ) or var_source in ("teacher", "both"): + tensors_to_monitor["teacher"] = teacher_latent + + # Compute effective rank + if self.effective_rank_config.get("enabled", True): + sample_size = self.effective_rank_config.get("sample_size", 2048) + for name, tensor in tensors_to_monitor.items(): + if tensor is not None: + source = self.effective_rank_config.get("tensor_source", "both") + if source == "both" or source == name: + eff_rank = self._compute_effective_rank(tensor, sample_size) + metrics[f"collapse.{name}.effective_rank"] = eff_rank + + # Compute singular value spectrum + if self.singular_values_config.get("enabled", True): + sample_size = self.singular_values_config.get("sample_size", 2048) + for name, tensor in tensors_to_monitor.items(): + if tensor is not None: + source = self.singular_values_config.get("tensor_source", "both") + if source == "both" or source == name: + sv_metrics = self._compute_singular_values(tensor, sample_size) + for key, value in sv_metrics.items(): + metrics[f"collapse.{name}.{key}"] = value + + # Compute per-dimension variance + if self.dimension_variance_config.get("enabled", True): + for name, tensor in tensors_to_monitor.items(): + if tensor is not None: + source = self.dimension_variance_config.get("tensor_source", "both") + if source == "both" or source == name: + var_metrics = self._compute_dimension_variance(tensor) + for key, value in var_metrics.items(): + metrics[f"collapse.{name}.{key}"] = value + + # Compute prototype entropy (DINO only) + if ( + self.prototype_entropy_config.get("enabled", True) + and prototype_probs is not None + and loss_type == "DINO" + ): + entropy = self._compute_prototype_entropy(prototype_probs) + metrics["collapse.dino.prototype_entropy"] = entropy + + # Log EMA beta + if self.ema_beta_config.get("enabled", True) and ema_beta is not None: + metrics["collapse.ema_beta"] = ema_beta + + # Cache metrics for averaging + for key, value in metrics.items(): + self._metrics_cache[key].append(value) + + return metrics + + def get_cached_metrics(self) -> dict[str, float]: + """ + Get averaged cached metrics and clear the cache. + + Returns: + Dictionary of averaged metrics since last call. + """ + averaged_metrics: dict[str, float] = {} + for key, values in self._metrics_cache.items(): + if values: + averaged_metrics[key] = sum(values) / len(values) + + self._metrics_cache.clear() + return averaged_metrics + + def _flatten_to_samples(self, z: torch.Tensor) -> torch.Tensor: + """ + Flatten patch dimension into sample dimension. + + Treats [B, N, D] as [B*N, D] where each patch is an independent sample. + This is consistent with C-JEPA/VICReg approach. + + Args: + z: Tensor of shape [B, N, D] or [B, D]. + + Returns: + Tensor of shape [B*N, D] or [B, D]. 
+ """ + # Convert to float32 for SVD compatibility (bfloat16/float16 can fail) + if z.dtype in (torch.bfloat16, torch.float16): + z = z.float() + + if z.ndim == 3: + return z.reshape(-1, z.shape[-1]) + return z + + def _sample_rows(self, z: torch.Tensor, sample_size: int) -> torch.Tensor: + """ + Randomly sample rows to reduce SVD computation cost. + + Args: + z: Tensor of shape [N, D]. + sample_size: Maximum number of samples (0 = no sampling). + + Returns: + Sampled tensor of shape [min(N, sample_size), D]. + """ + if sample_size <= 0 or z.shape[0] <= sample_size: + return z + + indices = torch.randperm(z.shape[0], device=z.device)[:sample_size] + return z[indices] + + def _compute_effective_rank(self, z: torch.Tensor, sample_size: int = 2048) -> float: + """ + Compute effective rank via entropy of normalized singular values (RankMe). + + The effective rank measures how many dimensions are actually being used + in the representation. A low effective rank indicates collapse. + + Args: + z: Latent representations [B, N, D] or [B, D]. + sample_size: Maximum samples for SVD computation. + + Returns: + Effective rank (exp of entropy of normalized singular values). + """ + z = self._flatten_to_samples(z.detach()) + z = self._sample_rows(z, sample_size) + + # Validate tensor before SVD + if z.numel() == 0: + logger.warning("Empty tensor in effective rank computation") + return 0.0 + if torch.isnan(z).any() or torch.isinf(z).any(): + logger.warning("NaN/Inf values in tensor for effective rank computation") + return 0.0 + if z.shape[0] < 2 or z.shape[1] < 2: + logger.warning(f"Tensor too small for SVD: shape={z.shape}") + return 0.0 + + # Center the data + z_centered = z - z.mean(dim=0, keepdim=True) + + # Compute SVD + try: + _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) + except RuntimeError as e: + # SVD can fail on degenerate matrices + logger.warning(f"SVD failed in effective rank computation: {e}, shape={z.shape}") + return 0.0 + + # Normalize singular values to get a probability distribution + s_normalized = s / (s.sum() + 1e-8) + + # Compute entropy + entropy = -torch.sum(s_normalized * torch.log(s_normalized + 1e-8)) + + # Effective rank is exp(entropy) + effective_rank = torch.exp(entropy) + + return effective_rank.item() + + def _compute_singular_values( + self, z: torch.Tensor, sample_size: int = 2048 + ) -> dict[str, float]: + """ + Compute singular value statistics and concentration ratio. + + The concentration ratio (top SV / sum of all SVs) indicates how much + variance is captured by the largest singular value. High concentration + suggests dimensional collapse. + + Args: + z: Latent representations [B, N, D] or [B, D]. + sample_size: Maximum samples for SVD computation. + + Returns: + Dictionary with sv_min, sv_max, sv_mean, and sv_concentration. 
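To make the RankMe-style metric concrete, a small self-contained check (independent of the class above) that compares a healthy representation with a collapsed one; exact numbers depend on the seed:

import torch

def effective_rank(z: torch.Tensor) -> float:
    # exp(entropy of the normalized singular values), as in RankMe (arXiv:2210.02885)
    z = z - z.mean(dim=0, keepdim=True)
    s = torch.linalg.svdvals(z)
    p = s / (s.sum() + 1e-8)
    return torch.exp(-(p * torch.log(p + 1e-8)).sum()).item()

torch.manual_seed(0)
healthy = torch.randn(2048, 128)                        # spreads over all 128 dimensions
collapsed = torch.randn(2048, 1) @ torch.randn(1, 128)  # rank-1 ("collapsed") embedding
print(effective_rank(healthy))    # close to 128
print(effective_rank(collapsed))  # close to 1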
+ """ + z = self._flatten_to_samples(z.detach()) + z = self._sample_rows(z, sample_size) + + # Validate tensor before SVD + if z.numel() == 0: + logger.warning("Empty tensor in singular value computation") + return {} + if torch.isnan(z).any() or torch.isinf(z).any(): + logger.warning("NaN/Inf values in tensor for singular value computation") + return {} + if z.shape[0] < 2 or z.shape[1] < 2: + logger.warning(f"Tensor too small for SVD: shape={z.shape}") + return {} + + # Center the data + z_centered = z - z.mean(dim=0, keepdim=True) + + # Compute SVD + try: + _, s, _ = torch.linalg.svd(z_centered, full_matrices=False) + except RuntimeError as e: + logger.warning(f"SVD failed in singular value computation: {e}, shape={z.shape}") + return {} + + metrics: dict[str, float] = {} + + # Singular value statistics + metrics["sv_min"] = s.min().item() + metrics["sv_max"] = s.max().item() + metrics["sv_mean"] = s.mean().item() + + # Concentration ratio (top SV / sum) + s_sum = s.sum() + 1e-8 + metrics["sv_concentration"] = (s[0] / s_sum).item() + + return metrics + + def _compute_dimension_variance(self, z: torch.Tensor) -> dict[str, float]: + """ + Compute per-dimension variance statistics. + + Low minimum variance indicates "dead" dimensions that are not being used. + Large variance ratio (max/min) suggests imbalanced dimension usage. + + Args: + z: Latent representations [B, N, D] or [B, D]. + + Returns: + Dictionary with var_min, var_mean, var_max. + """ + z = self._flatten_to_samples(z.detach()) + + # Compute variance along sample dimension + var_per_dim = z.var(dim=0) + + return { + "var_min": var_per_dim.min().item(), + "var_mean": var_per_dim.mean().item(), + "var_max": var_per_dim.max().item(), + } + + def _compute_prototype_entropy(self, probs: torch.Tensor) -> float: + """ + Compute normalized entropy of DINO prototype assignments. + + Low entropy indicates collapse to few prototypes. Entropy is normalized + to [0, 1] range where 1 means uniform distribution. + + Args: + probs: Post-softmax prototype assignment probabilities [B, K]. + + Returns: + Normalized entropy in [0, 1]. 
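The prototype-entropy metric admits a similarly small check: a uniform assignment over K prototypes gives normalized entropy near 1, while assignments concentrated on a single prototype give a value near 0 (a standalone sketch of the computation defined here):

import torch

def prototype_entropy(probs: torch.Tensor) -> float:
    avg = probs.mean(dim=0)                      # prototype usage distribution
    ent = -(avg * torch.log(avg + 1e-8)).sum()
    return (ent / torch.log(torch.tensor(float(probs.shape[1])))).item()

uniform = torch.full((64, 256), 1.0 / 256)  # every prototype used equally
peaked = torch.zeros(64, 256)
peaked[:, 0] = 1.0                          # all mass on one prototype
print(prototype_entropy(uniform))  # ~1.0
print(prototype_entropy(peaked))   # ~0.0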
+ """ + probs = probs.detach() + + # Average across batch to get prototype usage distribution + avg_probs = probs.mean(dim=0) + + # Compute entropy + entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-8)) + + # Normalize by maximum possible entropy (uniform distribution) + num_prototypes = probs.shape[1] + max_entropy = torch.log(torch.tensor(float(num_prototypes), device=probs.device)) + + normalized_entropy = entropy / (max_entropy + 1e-8) + + return normalized_entropy.item() diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index 64e104827..61be2848b 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -18,7 +18,7 @@ import weathergen.train.loss_modules.loss_functions as loss_fns from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues -from weathergen.utils.train_logger import TRAIN, VAL, Stage +from weathergen.train.utils import TRAIN, VAL, Stage _logger = logging.getLogger(__name__) diff --git a/src/weathergen/train/lr_scheduler.py b/src/weathergen/train/lr_scheduler.py index e85cd1abf..db90eb65f 100644 --- a/src/weathergen/train/lr_scheduler.py +++ b/src/weathergen/train/lr_scheduler.py @@ -208,18 +208,18 @@ def step(self): if self.i_step > 0 else self.lr_max_scaled ) - for g in self.optimizer.param_groups: - g["lr"] = self.lr + self._set_param_group_lrs(self.lr) elif self.policy_decay == "constant" and phase_decay: cur_lr = self.lr self.lr = self.lr_max_scaled # make sure lr_max_scaled rate is used if warm-up end is not lr_max_scaled if cur_lr < self.lr: - for g in self.optimizer.param_groups: - g["lr"] = self.lr + self._set_param_group_lrs(self.lr) else: self.cur_scheduler.step() self.lr = self.cur_scheduler.get_last_lr()[0] + # Apply per-group LR multipliers after scheduler step + self._apply_lr_multipliers() # switch scheduler when learning rate regime completed if self.i_step == self.n_steps_warmup: @@ -237,6 +237,33 @@ def step(self): return self.lr + def _set_param_group_lrs(self, base_lr: float): + """ + Set learning rates for all parameter groups, applying per-group multipliers. + + For Muon+AdamW composite optimizers, Muon parameter groups have an lr_multiplier + that scales their learning rate relative to the base LR. + + Args: + base_lr: The base learning rate to set. + """ + for g in self.optimizer.param_groups: + lr_multiplier = g.get("lr_multiplier", 1.0) + g["lr"] = base_lr * lr_multiplier + + def _apply_lr_multipliers(self): + """ + Apply per-group LR multipliers after a scheduler step. + + The scheduler sets the same LR for all groups, so we need to scale + Muon groups by their lr_multiplier afterwards. + """ + for g in self.optimizer.param_groups: + if g.get("is_muon", False): + lr_multiplier = g.get("lr_multiplier", 1.0) + # Scale Muon groups relative to base LR + g["lr"] = self.lr * lr_multiplier + def get_lr(self): return self.lr diff --git a/src/weathergen/train/optimizer.py b/src/weathergen/train/optimizer.py new file mode 100644 index 000000000..65890b76a --- /dev/null +++ b/src/weathergen/train/optimizer.py @@ -0,0 +1,650 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +Optimizer module for WeatherGenerator. + +Provides support for: +- Standard AdamW optimizer +- Hybrid Muon+AdamW optimizer (Muon for 2D hidden weights, AdamW for embeddings/heads) + +The Muon optimizer uses orthogonalization of gradients for improved training dynamics +on transformer hidden layer weights. See: https://arxiv.org/abs/2407.01490 +""" + +import logging +from typing import Any + +import numpy as np +import torch +from torch.optim import Optimizer + +logger = logging.getLogger(__name__) + + +# Patterns identifying parameters that should use AdamW (not Muon) +# These include embeddings, prediction heads, and other 1D or special parameters +ADAMW_PATTERNS = [ + "embed_target_coords", + "embeds.", + "embed.", + "unembed", + "pred_heads", + "latent_heads", + "q_cells", + "bilin", + "class_token", + "register_token", + "norm", + "bias", +] + + +def classify_muon_params( + model: torch.nn.Module, +) -> tuple[list[torch.nn.Parameter], list[torch.nn.Parameter], list[str], list[str]]: + """ + Classify model parameters into Muon-eligible and AdamW-eligible groups. + + Muon is applied to 2D hidden layer weights (attention Q/K/V/O, MLP linear layers). + AdamW is applied to embeddings, output heads, 1D parameters, and biases. + + Args: + model: The model whose parameters to classify. + + Returns: + A tuple of (muon_params, adamw_params, muon_names, adamw_names). + """ + muon_params: list[torch.nn.Parameter] = [] + adamw_params: list[torch.nn.Parameter] = [] + muon_names: list[str] = [] + adamw_names: list[str] = [] + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + name_lower = name.lower() + + # 1D parameters (biases, layer norm weights) -> AdamW + if param.ndim < 2: + adamw_params.append(param) + adamw_names.append(name) + continue + + # Check if parameter matches any AdamW pattern + is_adamw = any(pattern in name_lower for pattern in ADAMW_PATTERNS) + + if is_adamw: + adamw_params.append(param) + adamw_names.append(name) + else: + # 2D hidden weights -> Muon + muon_params.append(param) + muon_names.append(name) + + return muon_params, adamw_params, muon_names, adamw_names + + +def _scale_adamw_betas( + beta1_base: float, + beta2_base: float, + eps_base: float, + batch_size_total: int, +) -> tuple[float, float, float]: + """ + Scale AdamW hyperparameters based on batch size following SDE scaling rules. + + See: https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + + Args: + beta1_base: Base beta1 value (target for batch_size_total=1). + beta2_base: Base beta2 value (target for batch_size_total=1). + eps_base: Base epsilon value. + batch_size_total: Total effective batch size across all ranks. + + Returns: + Tuple of (scaled_beta1, scaled_beta2, scaled_eps). + """ + kappa = batch_size_total + beta1 = max(0.5, 1.0 - kappa * (1.0 - beta1_base)) + beta2 = 1.0 - kappa * (1.0 - beta2_base) + eps = eps_base / np.sqrt(kappa) + return beta1, beta2, eps + + +def create_optimizer( + model: torch.nn.Module, + optimizer_cfg: Any, + lr_cfg: Any, + batch_size_total: int, +) -> Optimizer: + """ + Factory function to create the appropriate optimizer based on config. + + Args: + model: The model to optimize. + optimizer_cfg: Optimizer configuration containing type and hyperparameters. 
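The parameter split and the batch-size scaling of the AdamW betas can be checked in isolation. A small sketch; the Toy module and all numeric values are illustrative only:

import torch
from weathergen.train.optimizer import classify_muon_params  # helper defined above

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Linear(8, 16)    # matches the "embed." pattern -> AdamW
        self.hidden = torch.nn.Linear(16, 16)  # 2D hidden weight -> Muon (its bias -> AdamW)

muon_p, adamw_p, muon_names, adamw_names = classify_muon_params(Toy())
print(muon_names)   # ['hidden.weight']
print(adamw_names)  # ['embed.weight', 'embed.bias', 'hidden.bias']

# SDE-style scaling of the AdamW hyperparameters, as in _scale_adamw_betas,
# with purely illustrative base values and an example total batch size of 16:
beta1_base, beta2_base, eps_base, kappa = 0.99, 0.995, 2e-08, 16
beta1 = max(0.5, 1.0 - kappa * (1.0 - beta1_base))  # 1 - 16*0.010 = 0.84
beta2 = 1.0 - kappa * (1.0 - beta2_base)            # 1 - 16*0.005 = 0.92
eps = eps_base / kappa**0.5                         # 2e-8 / 4     = 5e-9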
+ lr_cfg: Learning rate configuration containing lr_start. + batch_size_total: Total effective batch size across all ranks. + + Returns: + The configured optimizer (AdamW or CompositeOptimizer). + """ + optimizer_type = optimizer_cfg.get("type", "adamw") + initial_lr = lr_cfg.lr_start + weight_decay = optimizer_cfg.weight_decay + + # Scale AdamW betas based on batch size + adamw_cfg = optimizer_cfg.adamw + beta1, beta2, eps = _scale_adamw_betas( + adamw_cfg.beta1, + adamw_cfg.beta2, + adamw_cfg.get("eps", 2e-08), + batch_size_total, + ) + + if optimizer_type == "adamw": + logger.info("Creating AdamW optimizer") + return torch.optim.AdamW( + model.parameters(), + lr=initial_lr, + weight_decay=weight_decay, + betas=(beta1, beta2), + eps=eps, + ) + + elif optimizer_type == "muon_adamw": + logger.info("Creating Muon+AdamW composite optimizer") + return _create_muon_adamw_optimizer( + model=model, + optimizer_cfg=optimizer_cfg, + initial_lr=initial_lr, + weight_decay=weight_decay, + adamw_betas=(beta1, beta2), + adamw_eps=eps, + batch_size_total=batch_size_total, + ) + + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + + +def _create_muon_adamw_optimizer( + model: torch.nn.Module, + optimizer_cfg: Any, + initial_lr: float, + weight_decay: float, + adamw_betas: tuple[float, float], + adamw_eps: float, + batch_size_total: int, +) -> "CompositeOptimizer": + """ + Create a Muon+AdamW composite optimizer. + + Args: + model: The model to optimize. + optimizer_cfg: Optimizer configuration. + initial_lr: Initial learning rate (for AdamW; Muon uses multiplied version). + weight_decay: Weight decay coefficient. + adamw_betas: Scaled (beta1, beta2) for AdamW. + adamw_eps: Scaled epsilon for AdamW. + batch_size_total: Total effective batch size. + + Returns: + CompositeOptimizer wrapping Muon and AdamW. 
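A usage sketch for the factory; the tiny model and hyperparameter values are illustrative, and OmegaConf is assumed only because the config objects need both attribute and .get() access:

import torch
from omegaconf import OmegaConf
from weathergen.train.optimizer import create_optimizer

model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.GELU(), torch.nn.Linear(64, 32))
optimizer_cfg = OmegaConf.create({
    "type": "muon_adamw",
    "weight_decay": 0.05,
    "adamw": {"beta1": 0.99, "beta2": 0.995, "eps": 2e-08},
    "muon": {"lr_multiplier": 30.0, "momentum": 0.95, "nesterov": True, "weight_decay": 0.05},
})
lr_cfg = OmegaConf.create({"lr_start": 1e-6})

opt = create_optimizer(model, optimizer_cfg, lr_cfg, batch_size_total=16)
for g in opt.param_groups:
    print(g.get("is_muon", False), g.get("lr_multiplier", 1.0), g["lr"])
# Muon groups report lr = lr_start * lr_multiplier; the scheduler changes above read
# the is_muon / lr_multiplier keys so those groups keep tracking base_lr * multiplier.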
+ """ + muon_cfg = optimizer_cfg.get("muon", {}) + lr_multiplier = muon_cfg.get("lr_multiplier", 20.0) + muon_momentum = muon_cfg.get("momentum", 0.95) + muon_nesterov = muon_cfg.get("nesterov", True) + muon_weight_decay = muon_cfg.get("weight_decay", weight_decay) + + # Classify parameters + muon_params, adamw_params, muon_names, adamw_names = classify_muon_params(model) + + logger.info(f"Muon parameters ({len(muon_params)}): {muon_names[:5]}...") + logger.info(f"AdamW parameters ({len(adamw_params)}): {adamw_names[:5]}...") + + # Create parameter groups for AdamW + # Include both AdamW-only params and mark them appropriately + adamw_param_groups = [ + { + "params": adamw_params, + "lr": initial_lr, + "is_muon": False, + "lr_multiplier": 1.0, + } + ] + + # Create AdamW optimizer for embeddings/heads + adamw_optimizer = torch.optim.AdamW( + adamw_param_groups, + lr=initial_lr, + weight_decay=weight_decay, + betas=adamw_betas, + eps=adamw_eps, + ) + + # Create Muon optimizer for hidden weights + muon_lr = initial_lr * lr_multiplier + + # Parameter groups for Muon + muon_param_groups = [ + { + "params": muon_params, + "lr": muon_lr, + "is_muon": True, + "lr_multiplier": lr_multiplier, + } + ] + + # Try to use PyTorch's built-in Muon if available (PyTorch >= 2.9) + muon_optimizer = _create_muon_optimizer( + param_groups=muon_param_groups, + lr=muon_lr, + momentum=muon_momentum, + nesterov=muon_nesterov, + weight_decay=muon_weight_decay, + ) + + return CompositeOptimizer( + muon_optimizer=muon_optimizer, + adamw_optimizer=adamw_optimizer, + muon_lr_multiplier=lr_multiplier, + ) + + +def _create_muon_optimizer( + param_groups: list[dict], + lr: float, + momentum: float, + nesterov: bool, + weight_decay: float, +) -> Optimizer: + """ + Create a Muon optimizer, using PyTorch's built-in version if available. + + Falls back to custom implementation for older PyTorch versions. + """ + # Try PyTorch's built-in Muon (available in PyTorch >= 2.9) + if hasattr(torch.optim, "Muon"): + logger.info("Using torch.optim.Muon") + return torch.optim.Muon( + param_groups, + lr=lr, + momentum=momentum, + nesterov=nesterov, + weight_decay=weight_decay, + ) + else: + logger.warning( + "Using custom Muon implementation (torch.optim.Muon not available). " + "NOTE: This implementation does NOT support FSDP2. Use DDP or single-GPU training." + ) + return MuonCustom( + param_groups, + lr=lr, + momentum=momentum, + nesterov=nesterov, + weight_decay=weight_decay, + ) + + +class CompositeOptimizer(Optimizer): + """ + Composite optimizer that combines Muon and AdamW for different parameter groups. + + Muon is used for 2D hidden layer weights, AdamW for embeddings and heads. + This class wraps both optimizers and provides a unified interface. + + Inherits from Optimizer for compatibility with PyTorch LR schedulers. + """ + + def __init__( + self, + muon_optimizer: Optimizer, + adamw_optimizer: Optimizer, + muon_lr_multiplier: float = 20.0, + ): + """ + Initialize the composite optimizer. + + Args: + muon_optimizer: Optimizer for Muon-eligible parameters. + adamw_optimizer: Optimizer for AdamW-eligible parameters. + muon_lr_multiplier: LR multiplier for Muon relative to base LR. 
+ """ + self.muon_optimizer = muon_optimizer + self.adamw_optimizer = adamw_optimizer + self.muon_lr_multiplier = muon_lr_multiplier + + # Manually initialize Optimizer base class attributes without calling __init__ + # This avoids the param_groups setup that would conflict with our combined groups + from collections import OrderedDict, defaultdict + + # Set defaults with betas for LR scheduler compatibility (OneCycleLR checks this) + # Use AdamW's betas since that's the more common scheduler interaction + adamw_betas = adamw_optimizer.defaults.get("betas", (0.9, 0.999)) + self.defaults = { + "betas": adamw_betas, + "momentum": muon_optimizer.defaults.get("momentum", 0.95), + } + self._optimizer_step_pre_hooks = OrderedDict() + self._optimizer_step_post_hooks = OrderedDict() + self._optimizer_state_dict_pre_hooks = OrderedDict() + self._optimizer_state_dict_post_hooks = OrderedDict() + self._optimizer_load_state_dict_pre_hooks = OrderedDict() + self._optimizer_load_state_dict_post_hooks = OrderedDict() + + # Ensure all param groups have betas for OneCycleLR compatibility + # OneCycleLR with cycle_momentum=True tries to modify betas on ALL groups + for group in muon_optimizer.param_groups: + if "betas" not in group: + group["betas"] = adamw_betas + + # Combined param_groups from both optimizers + self.param_groups = muon_optimizer.param_groups + adamw_optimizer.param_groups + + # State is a combined view (we override the property below) + self._state = defaultdict(dict) + + def step(self, closure=None): + """ + Perform a single optimization step. + + Args: + closure: A closure that reevaluates the model and returns the loss. + Optional for most optimizers. + + Returns: + Loss value if closure is provided, None otherwise. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + self.muon_optimizer.step() + self.adamw_optimizer.step() + + return loss + + def zero_grad(self, set_to_none: bool = True): + """ + Reset gradients of all optimized parameters. + + Args: + set_to_none: If True, set gradients to None instead of zero. + This can improve memory efficiency. + """ + self.muon_optimizer.zero_grad(set_to_none=set_to_none) + self.adamw_optimizer.zero_grad(set_to_none=set_to_none) + + def state_dict(self) -> dict: + """ + Return the state of both optimizers as a single dictionary. + + Returns: + Dictionary containing state from both Muon and AdamW optimizers. + """ + return { + "muon": self.muon_optimizer.state_dict(), + "adamw": self.adamw_optimizer.state_dict(), + "muon_lr_multiplier": self.muon_lr_multiplier, + "optimizer_type": "composite_muon_adamw", + } + + def load_state_dict(self, state_dict: dict): + """ + Load optimizer state from a dictionary. + + Args: + state_dict: Dictionary containing saved optimizer state. + """ + if ( + "optimizer_type" in state_dict + and state_dict["optimizer_type"] == "composite_muon_adamw" + ): + self.muon_optimizer.load_state_dict(state_dict["muon"]) + self.adamw_optimizer.load_state_dict(state_dict["adamw"]) + self.muon_lr_multiplier = state_dict.get("muon_lr_multiplier", self.muon_lr_multiplier) + else: + # Fallback: try to load as regular optimizer state + # This handles migration from pure AdamW checkpoints + logger.warning( + "Loading non-composite state dict into CompositeOptimizer. " + "This may not work correctly - optimizer state may be lost." + ) + + @property + def state(self) -> dict: + """ + Return combined state from both optimizers. 
+ + This provides a unified view of optimizer state for checkpointing. + """ + combined_state = dict(self._state) + combined_state.update(self.muon_optimizer.state) + combined_state.update(self.adamw_optimizer.state) + return combined_state + + @state.setter + def state(self, value): + """Set state (needed for Optimizer base class compatibility).""" + self._state = value + + +def _zeropower_via_newtonschulz5(grad: torch.Tensor, steps: int) -> torch.Tensor: + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of grad. + + Uses quintic iteration with coefficients selected to maximize the slope at zero. + This produces something like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), + rather than exact UV^T, but this doesn't hurt model performance. + + Reference: https://github.com/KellerJordan/Muon + + Args: + grad: Gradient tensor (must be at least 2D). + steps: Number of Newton-Schulz iterations. + + Returns: + Orthogonalized gradient tensor. + """ + assert grad.ndim >= 2 + coef_a, coef_b, coef_c = (3.4445, -4.7750, 2.0315) + x = grad.bfloat16() + + # Transpose if more rows than columns (NS works better on wide matrices) + if grad.size(-2) > grad.size(-1): + x = x.mT + + # Normalize by spectral norm (approximated by Frobenius norm for stability) + x = x / (x.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Perform Newton-Schulz iterations with quintic coefficients + for _ in range(steps): + xxt = x @ x.mT + poly = coef_b * xxt + coef_c * xxt @ xxt + x = coef_a * x + poly @ x + + # Restore original orientation + if grad.size(-2) > grad.size(-1): + x = x.mT + + return x + + +def _muon_update( + grad: torch.Tensor, + momentum_buffer: torch.Tensor, + beta: float = 0.95, + ns_steps: int = 5, + nesterov: bool = True, +) -> torch.Tensor: + """ + Compute Muon update: momentum + orthogonalization + scaling. + + Args: + grad: Parameter gradient. + momentum_buffer: Momentum buffer (modified in-place). + beta: Momentum coefficient. + ns_steps: Number of Newton-Schulz iterations. + nesterov: Whether to use Nesterov momentum. + + Returns: + The update to apply to parameters. + """ + # Momentum accumulation using lerp for numerical stability + momentum_buffer.lerp_(grad, 1 - beta) + + # Compute update (Nesterov or standard momentum) + if nesterov: + update = grad.lerp(momentum_buffer, beta) + else: + update = momentum_buffer.clone() + + # Reshape for orthogonalization if needed (e.g., conv filters) + original_shape = update.shape + if update.ndim == 4: + update = update.view(len(update), -1) + + # Apply Newton-Schulz orthogonalization + update = _zeropower_via_newtonschulz5(update, steps=ns_steps) + + # Scale by sqrt(max(1, rows/cols)) to preserve gradient magnitude + update = update * max(1, update.size(-2) / update.size(-1)) ** 0.5 + + # Restore original shape and dtype + return update.to(grad.dtype).view(original_shape) + + +class MuonCustom(Optimizer): + """ + Custom Muon optimizer implementation based on Keller Jordan's reference. + + Muon (MomentUm Orthogonalized by Newton-schulz) internally runs standard SGD-momentum, + then performs an orthogonalization post-processing step where each 2D parameter's update + is replaced with the nearest orthogonal matrix via Newton-Schulz iteration. + + Reference: https://github.com/KellerJordan/Muon + https://kellerjordan.github.io/posts/muon/ + + Note: Muon should only be used for hidden weight layers. Embeddings, output heads, + biases, and layer norms should use AdamW. 
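A small sanity check of the Newton-Schulz step on a matrix with a known spectrum; it assumes the weathergen package from this diff is importable, and the logspace spectrum is illustrative:

import torch
from weathergen.train.optimizer import _zeropower_via_newtonschulz5

torch.manual_seed(0)
s = torch.logspace(-1, 1, 256)                 # singular values spanning two decades
u, _ = torch.linalg.qr(torch.randn(256, 256))
v, _ = torch.linalg.qr(torch.randn(512, 256))
g = u @ torch.diag(s) @ v.T                    # synthetic "gradient", shape [256, 512]

ortho = _zeropower_via_newtonschulz5(g, steps=5).float()
print(torch.linalg.svdvals(g).min().item(), torch.linalg.svdvals(g).max().item())          # ~0.1, ~10
print(torch.linalg.svdvals(ortho).min().item(), torch.linalg.svdvals(ortho).max().item())
# After five quintic iterations the singular values of the update sit roughly in the
# (0.5, 1.5) band described in the docstring, i.e. the update direction is
# near-orthogonal even though the input spectrum spans two orders of magnitude.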
+ + WARNING: This implementation does NOT support FSDP2 (Fully Sharded Data Parallel). + The Newton-Schulz orthogonalization requires the FULL gradient matrix, but FSDP2 + shards gradients across GPUs. Computing `X @ X.T` on a sharded gradient gives + mathematically incorrect results. For FSDP2 support, see the distributed version + in the reference implementation: https://github.com/KellerJordan/Muon/blob/master/muon.py + """ + + def __init__( + self, + params, + lr: float = 0.02, + momentum: float = 0.95, + nesterov: bool = True, + weight_decay: float = 0.0, + ns_steps: int = 5, + ): + """ + Initialize the Muon optimizer. + + Args: + params: Iterable of parameters to optimize or dicts defining param groups. + lr: Learning rate (in units of spectral norm per update). + momentum: Momentum factor (0.95 is typically good). + nesterov: Whether to use Nesterov momentum. + weight_decay: Decoupled weight decay (like AdamW). + ns_steps: Number of Newton-Schulz iterations for orthogonalization. + """ + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + nesterov=nesterov, + weight_decay=weight_decay, + ns_steps=ns_steps, + ) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """ + Perform a single optimization step. + + Args: + closure: A closure that reevaluates the model and returns the loss. + + Returns: + Loss value if closure is provided, None otherwise. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Check for FSDP2 on first step (DTensor indicates sharded parameters) + if not hasattr(self, "_fsdp_checked"): + self._fsdp_checked = True + for group in self.param_groups: + for p in group["params"]: + # FSDP2 uses DTensor for sharding + is_fsdp2 = hasattr(p, "_local_tensor") or type(p).__name__ == "DTensor" + assert not is_fsdp2, ( + "MuonCustom does not support FSDP2 (Fully Sharded Data Parallel). " + "The Newton-Schulz orthogonalization requires full gradients, but " + "FSDP2 shards gradients across GPUs, leading to incorrect results. " + "Options: (1) Use DDP instead of FSDP, (2) Use AdamW optimizer, " + "(3) Use torch.optim.Muon (PyTorch >= 2.9) if it supports FSDP. 
" + "Reference FSDP impl: github.com/KellerJordan/Muon/blob/master/muon.py" + ) + + for group in self.param_groups: + momentum = group["momentum"] + nesterov = group["nesterov"] + lr = group["lr"] + weight_decay = group["weight_decay"] + ns_steps = group.get("ns_steps", 5) + + for p in group["params"]: + if p.grad is None: + continue + + # Initialize momentum buffer if needed + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + + # Compute Muon update + update = _muon_update( + p.grad, + state["momentum_buffer"], + beta=momentum, + ns_steps=ns_steps, + nesterov=nesterov, + ) + + # Apply decoupled weight decay FIRST (like AdamW) + if weight_decay != 0: + p.mul_(1 - lr * weight_decay) + + # Apply the orthogonalized update + p.add_(update.view(p.shape), alpha=-lr) + + return loss diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index cfb252f86..58dac7702 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -9,10 +9,13 @@ from __future__ import annotations -from typing import Any +import logging +from typing import TYPE_CHECKING import torch +import torch.nn as nn +from weathergen.model.engines import LatentPredictionHeadIdentity from weathergen.model.ssl_target_processing import ( DINOTargetProcessing, JEPATargetProcessing, @@ -20,51 +23,123 @@ ) from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase, TargetAuxOutput +if TYPE_CHECKING: + from omegaconf import DictConfig -class EMATeacher(TargetAndAuxModuleBase): - def __init__(self, model, ema_model, batch_size, training_cfg, **kwargs): - # One of the issues is that the teacher model may have a different architecture - # to the student, e.g. JEPA. So we need quite a flexible way to instantiate the - # the teacher. Because of the device sharding etc that requires quite a bit of - # massaging we assume that the teacher creates the EMA model correctly. However, - # note that you cannot assume that model.state_dict equals ema_model.state_dict - self.ema_model = ema_model - self.batch_size = batch_size + from weathergen.common.config import Config + +logger = logging.getLogger(__name__) + + +class EncoderTeacher(TargetAndAuxModuleBase): + """Abstract base class for SSL teachers that use an encoder to generate targets. + + This class provides the common functionality for teacher models in student-teacher + SSL training setups. Subclasses must implement `_forward_teacher()` to define + how the teacher model generates outputs. + + Attributes: + teacher_model: The teacher model used to generate target representations. + postprocess_targets: Dict of postprocessing modules for each loss type. + """ + + def __init__(self, teacher_model, training_cfg: DictConfig, **kwargs): + """Initialize the EncoderTeacher. - # is a dict of TargetProcessing classes as we may use several in parallel + Args: + teacher_model: The teacher model (can be EMA model wrapper or frozen model). + training_cfg: Training configuration containing loss specifications. + Must have `losses` attribute with at least one LossLatentSSLStudentTeacher. + **kwargs: Additional arguments passed to postprocessing setup. + + Raises: + ValueError: If training_cfg has no LossLatentSSLStudentTeacher losses. 
+ """ + self.teacher_model = teacher_model + + # Parse SSL losses from config to set up target postprocessing + assert hasattr(training_cfg, "losses"), ( + f"EncoderTeacher requires training_cfg with 'losses' attribute, " + f"got {type(training_cfg).__name__}" + ) losses_cfg = [ v.loss_fcts for k, v in training_cfg.losses.items() if v.type == "LossLatentSSLStudentTeacher" ] + + if not losses_cfg: + raise ValueError( + "EncoderTeacher requires at least one 'LossLatentSSLStudentTeacher' loss " + "in training_config.losses. Found loss types: " + f"{[v.type for v in training_cfg.losses.values()]}" + ) + # TODO: support multiple LossLatentSSLStudentTeacher loss terms + if len(losses_cfg) > 1: + logger.warning( + f"Found {len(losses_cfg)} LossLatentSSLStudentTeacher losses, " + "but only the first one is used for target postprocessing." + ) + self.postprocess_targets = get_target_postprocessing(losses_cfg[0], training_cfg, **kwargs) - self.reset() + def _forward_teacher(self, model_params, batch): + """Execute forward pass on the teacher model. - def reset(self, batch_size=None): - self.ema_model.reset() - if batch_size is not None: - self.batch_size = batch_size + Subclasses must implement this method to define their specific forward behavior. - def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: - return + Args: + model_params: Model parameters for the forward pass. + batch: Input batch. - def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: - if self.ema_model.is_model_sharded: - self.ema_model.ema_model.reshard() - self.ema_model.update(istep, self.batch_size) + Returns: + Model output with get_latent_prediction() method. + + Raises: + NotImplementedError: If not implemented by subclass. + """ + raise NotImplementedError("Subclasses must implement _forward_teacher()") + + def compute(self, istep: int, batch, model_params, model) -> TargetAuxOutput: + """Compute target representations from the teacher model. - def compute(self, bidx, batch, model_params, model) -> tuple[Any, Any]: + Args: + istep: Training step index. + batch: Input batch with get_samples(), get_output_len(), get_output_idxs() methods. + model_params: Model parameters for the forward pass. + model: Student model (not used, but part of interface). + + Returns: + TargetAuxOutput containing latent targets and auxiliary outputs. + + Raises: + KeyError: If teacher model doesn't output a required loss type. + """ with torch.no_grad(): - outputs = self.ema_model.forward_eval(model_params, batch).get_latent_prediction(0) + model_output = self._forward_teacher(model_params, batch) + outputs = model_output.get_latent_prediction(0) + targets = {} for loss_name, target_module in self.postprocess_targets.items(): + if loss_name not in outputs: + available_keys = list(outputs.keys()) if hasattr(outputs, "keys") else "N/A" + raise KeyError( + f"Teacher model output missing key '{loss_name}'. " + f"Available keys: {available_keys}. " + f"Ensure teacher model has latent head for '{loss_name}'." 
+ ) targets[loss_name] = target_module(outputs[loss_name]) - # collect target meta-information for selected samples - aux_outputs = [list(sample.meta_info.values())[0] for sample in batch.get_samples()] + # Collect target meta-information for selected samples + samples = batch.get_samples() + aux_outputs = [] + for sample in samples: + if sample.meta_info: + aux_outputs.append(list(sample.meta_info.values())[0]) + else: + aux_outputs.append(None) targets_out = TargetAuxOutput(batch.get_output_len(), batch.get_output_idxs()) targets_out.latent = targets @@ -72,16 +147,391 @@ def compute(self, bidx, batch, model_params, model) -> tuple[Any, Any]: return targets_out - def to_device(self, device) -> EMATeacher: + def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: + """Update state before backward pass. Default is no-op.""" + return + + def to_device(self, device) -> EncoderTeacher: + """Move postprocessors to the specified device. + + Args: + device: Target device. + + Returns: + Self for method chaining. + """ for _, module in self.postprocess_targets.items(): module.to(device) return self + def get_current_beta(self, cur_step: int) -> float: + beta = self.ema_model.get_current_beta(cur_step) + return beta + + +class EMATeacher(EncoderTeacher): + """Teacher using Exponential Moving Average of student weights. + + This teacher maintains an EMA of the student model's weights and uses it + to generate target representations for SSL training. + """ + + def __init__(self, model, ema_model, batch_size: int, training_cfg: DictConfig, **kwargs): + """Initialize the EMATeacher. + + Args: + model: The student model (used for reference, weights copied to EMA). + ema_model: The EMA model wrapper that maintains averaged weights. + Must have reset(), update(), forward_eval() methods. + batch_size: Global batch size for EMA update scheduling. Must be positive. + training_cfg: Training configuration with SSL loss specifications. + **kwargs: Additional arguments passed to parent. + + Note: + The teacher model may have a different architecture to the student, + e.g. for JEPA. The ema_model handles weight copying appropriately. + You cannot assume model.state_dict equals ema_model.state_dict. + + Raises: + ValueError: If batch_size is not positive. + AssertionError: If ema_model lacks required methods. + """ + if batch_size <= 0: + raise ValueError(f"batch_size must be positive, got {batch_size}") + + # Validate ema_model interface + assert hasattr(ema_model, "reset"), "ema_model must have reset() method" + assert hasattr(ema_model, "update"), "ema_model must have update() method" + assert hasattr(ema_model, "forward_eval"), "ema_model must have forward_eval() method" + + self.ema_model = ema_model + self.batch_size = batch_size + super().__init__(ema_model, training_cfg, **kwargs) + self.reset() + + def _forward_teacher(self, model_params, batch): + """Execute forward pass using EMA model's forward_eval method.""" + return self.ema_model.forward_eval(model_params, batch) + + def reset(self, batch_size=None): + """Reset EMA model weights to match current student weights. + + Args: + batch_size: Optional new batch size to use for EMA updates. + """ + self.ema_model.reset() + if batch_size is not None: + self.batch_size = batch_size + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + """Update EMA weights after optimizer step. + + Args: + istep: Current training step. + batch: Current batch (unused). 
+ model: Student model (unused, EMA model tracks it internally). + **kwargs: Additional arguments (unused). + """ + if self.ema_model.is_model_sharded: + self.ema_model.ema_model.reshard() + self.ema_model.update(istep, self.batch_size) + -def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): +class FrozenTeacher(EncoderTeacher): + """Teacher loaded from a pre-trained checkpoint with frozen weights. + + This teacher uses a model loaded from a previous training run. The weights + are frozen and never updated during training. This is useful for distillation + from a pre-trained model as described in arXiv:2509.24317. + + The teacher model may have been pre-trained with any method (forecasting, MAE, etc.) + and doesn't need to have SSL latent heads. Identity heads are added automatically + for any SSL losses the student needs. + + Note: + This class intentionally does NOT call super().__init__() because: + 1. It sets up identity postprocessing (JEPATargetProcessing) for ALL losses, + regardless of what the student config specifies for DINO/iBOT + 2. The parent class would try to parse the teacher's training config for SSL losses, + but the teacher may have been trained without SSL (e.g., forecasting only) + + Warning: + This class modifies the teacher_model in-place by adding latent_heads if missing. + """ + + def __init__( + self, + teacher_model: nn.Module, + training_cfg: DictConfig | None, + teacher_model_params=None, + **kwargs, + ): + """Initialize the FrozenTeacher. + + Args: + teacher_model: Pre-trained model to use as teacher. Will be modified in-place + to add identity latent heads if they don't exist. + training_cfg: Current training configuration containing the student's SSL losses. + Used to determine which identity heads to add to the teacher. + If None, defaults to adding a JEPA head. + teacher_model_params: Model parameters matching the teacher's architecture + (positional embeddings, q_cells, etc.). If None, will use the student's + model_params which may cause dimension mismatch if architectures differ. + **kwargs: Additional arguments (unused, for interface compatibility). + """ + # Note: We intentionally don't call super().__init__() - see class docstring + self.teacher_model = teacher_model + self.teacher_model_params = teacher_model_params + + # Get required SSL loss names from current training config + required_heads = self._get_required_ssl_heads(training_cfg) + assert len(required_heads) > 0, "No SSL heads required - this should never happen" + + # Add identity heads to teacher if it doesn't have them (modifies model in-place) + self._ensure_identity_heads(teacher_model, required_heads) + + # Set up identity postprocessing for all SSL losses + # FrozenTeacher always uses identity (JEPATargetProcessing) regardless of loss type + self.postprocess_targets = {name: JEPATargetProcessing() for name in required_heads} + + # Freeze all parameters + for param in self.teacher_model.parameters(): + param.requires_grad = False + + # Set to eval mode permanently (affects BatchNorm, Dropout, etc.) + self.teacher_model.eval() + + def _get_required_ssl_heads(self, training_cfg: DictConfig | None) -> set[str]: + """Extract SSL loss names from training config. + + Args: + training_cfg: Training configuration containing losses specification. + If None, defaults to {"JEPA"}. + + Returns: + Set of SSL loss names (e.g., {"JEPA", "DINO"}). Never empty. 
+ """ + if training_cfg is None: + logger.debug("FrozenTeacher: No training_cfg provided, defaulting to JEPA head") + return {"JEPA"} + + if not hasattr(training_cfg, "losses"): + logger.warning( + "FrozenTeacher: training_cfg has no 'losses' attribute, defaulting to JEPA head" + ) + return {"JEPA"} + + required_heads = set() + for loss_name, loss_cfg in training_cfg.losses.items(): + if not hasattr(loss_cfg, "type"): + continue + if loss_cfg.type == "LossLatentSSLStudentTeacher": + if hasattr(loss_cfg, "loss_fcts"): + required_heads.update(loss_cfg.loss_fcts.keys()) + else: + logger.warning( + f"FrozenTeacher: Loss '{loss_name}' has type LossLatentSSLStudentTeacher " + "but no loss_fcts, skipping" + ) + + if not required_heads: + logger.debug( + "FrozenTeacher: No LossLatentSSLStudentTeacher losses found in config, " + "defaulting to JEPA head" + ) + return {"JEPA"} + + logger.debug(f"FrozenTeacher: Required SSL heads from config: {required_heads}") + return required_heads + + def _ensure_identity_heads(self, teacher_model: nn.Module, required_heads: set[str]) -> None: + """Add identity latent heads to teacher model if they don't exist. + + The teacher may have been pre-trained without SSL losses (e.g., forecasting). + We add identity heads so that `get_latent_prediction()` returns the raw + encoder representations (specifically, patch_tokens from LatentState) for + the student's SSL losses. + + Warning: + This method modifies teacher_model IN-PLACE by adding to its latent_heads. + + Args: + teacher_model: The teacher model to modify. Will have latent_heads added/modified. + required_heads: Set of head names that must exist (e.g., {"JEPA", "DINO"}). + """ + # Ensure latent_heads ModuleDict exists + if not hasattr(teacher_model, "latent_heads") or teacher_model.latent_heads is None: + logger.info("FrozenTeacher: Teacher model has no latent_heads, creating ModuleDict") + teacher_model.latent_heads = nn.ModuleDict() + + # Add missing identity heads + for head_name in sorted(required_heads): # sorted for deterministic logging + if head_name not in teacher_model.latent_heads: + logger.info( + f"FrozenTeacher: Adding identity head '{head_name}' to teacher model " + f"(teacher was likely pre-trained without SSL losses)" + ) + teacher_model.latent_heads[head_name] = LatentPredictionHeadIdentity() + + @classmethod + def from_pretrained(cls, cf: Config, dataset, device, params: dict) -> FrozenTeacher: + """Create a FrozenTeacher from a pre-trained checkpoint. + + This factory method: + 1. Loads the teacher's config from the checkpoint + 2. Creates a model with the teacher's architecture + 3. Loads the pre-trained weights + 4. Creates ModelParams matching the teacher's architecture + 5. Returns a FrozenTeacher instance + + Args: + cf: Current training configuration. Used for: + - model_path: Where to find saved models + - training_config: To determine which SSL heads are needed + dataset: Dataset for model creation (provides input/output dimensions). + device: Target device (e.g., "cuda:0", "cpu"). + params: FrozenTeacher parameters from config, including: + - teacher_run_id (required): 8-character run ID of the pre-trained teacher. + - teacher_mini_epoch (optional): Mini-epoch to load. Default -1 (latest). + + Returns: + FrozenTeacher instance with loaded and frozen weights. + + Raises: + ValueError: If teacher_run_id is not provided or invalid. + FileNotFoundError: If checkpoint doesn't exist (from load_run_config/load_model). 
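Independent of checkpoint loading, the constructor's head injection and weight freezing can be seen on a toy module; toy_teacher below stands in for a real pre-trained Model and the config is a minimal illustrative student config:

import torch.nn as nn
from omegaconf import OmegaConf
from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher

toy_teacher = nn.Linear(8, 8)  # placeholder for a model loaded via from_pretrained()
student_cfg = OmegaConf.create({
    "losses": {
        "student-teacher": {
            "type": "LossLatentSSLStudentTeacher",
            "loss_fcts": {"JEPA": {"head": "identity"}},
        }
    }
})

teacher = FrozenTeacher(toy_teacher, training_cfg=student_cfg)
print(list(toy_teacher.latent_heads.keys()))                       # ['JEPA'] (identity head added in-place)
print(all(not p.requires_grad for p in toy_teacher.parameters()))  # True -> weights frozen

In the trainer this object is produced by FrozenTeacher.from_pretrained(), which additionally loads the teacher's config, weights, and matching ModelParams.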
+ """ + # Lazy imports to avoid circular dependency with model_interface + from weathergen.common.config import load_run_config, merge_configs + from weathergen.model.model import ModelParams + from weathergen.model.model_interface import get_model, load_model + from weathergen.utils.distributed import is_root + + teacher_run_id = params.get("teacher_run_id") + teacher_mini_epoch = params.get("teacher_mini_epoch", -1) + + # Validate teacher_run_id + if teacher_run_id is None: + raise ValueError( + "FrozenTeacher requires 'teacher_run_id' in config. " + "Example config:\n" + " target_and_aux_calc:\n" + " FrozenTeacher:\n" + " teacher_run_id: 'a1b2c3d4'" + ) + + if not isinstance(teacher_run_id, str) or len(teacher_run_id) == 0: + raise ValueError( + f"teacher_run_id must be a non-empty string, got {type(teacher_run_id).__name__}: " + f"{teacher_run_id!r}" + ) + + if is_root(): + logger.info( + f"Loading FrozenTeacher from run_id={teacher_run_id}, " + f"mini_epoch={teacher_mini_epoch}" + ) + + # Load teacher's config (contains full architecture) + model_path = cf.get("model_path") + assert model_path is not None, "cf.model_path is required to load FrozenTeacher checkpoint" + + teacher_config = load_run_config(teacher_run_id, teacher_mini_epoch, model_path) + + # Disable FSDP/DDP for frozen teacher - it's loaded as a simple non-sharded model + # This avoids complications with distributed training for the teacher + teacher_config = merge_configs(teacher_config, {"with_ddp": False, "with_fsdp": False}) + + # Create model with teacher's architecture + teacher_model = get_model(teacher_config, "student", dataset, {}) + + # Load weights from checkpoint + teacher_model = load_model( + teacher_config, teacher_model, device, teacher_run_id, teacher_mini_epoch + ) + + # Create model params matching teacher's architecture + # This includes positional embeddings, q_cells, etc. that depend on architecture + teacher_model_params = ModelParams(teacher_config).create(teacher_config) + teacher_model_params = teacher_model_params.to(device) + + if is_root(): + num_params = sum(p.numel() for p in teacher_model.parameters()) + logger.info(f"FrozenTeacher loaded with {num_params:,} parameters") + + # Pass current training config so FrozenTeacher knows which SSL heads to add + return cls( + teacher_model, + training_cfg=cf.training_config, + teacher_model_params=teacher_model_params, + ) + + def _forward_teacher(self, model_params, batch): + """Execute forward pass on the frozen teacher model. + + Uses the teacher's own model_params instead of the student's to ensure + dimension compatibility. + """ + # Use teacher's model params if available, otherwise fall back to passed-in params + params_to_use = ( + self.teacher_model_params if self.teacher_model_params is not None else model_params + ) + return self.teacher_model(params_to_use, batch) + + def reset(self, batch_size=None): + """No-op: frozen teacher weights don't change.""" + pass + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + """No-op: frozen teacher weights don't change.""" + pass + + +def get_target_postprocessing( + target_losses: dict[str, DictConfig], training_cfg: DictConfig, **kwargs +) -> dict[str, nn.Module]: + """Create postprocessing modules for each SSL loss type. + + This function creates the appropriate postprocessing module for each SSL loss + based on its configuration. The postprocessing is applied to teacher outputs + before computing the student-teacher loss. 
+ + - JEPA: Identity (no postprocessing) + - DINO: Centering and temperature sharpening + - iBOT: Patch-level centering and temperature sharpening + + Args: + target_losses: Dict of loss configurations keyed by loss name (e.g., "JEPA", "DINO"). + Each value should have the required config keys for that loss type. + training_cfg: Training configuration (currently unused, reserved for future use). + **kwargs: Additional arguments (currently unused). + + Returns: + Dict mapping loss names to their postprocessing nn.Module instances. + + Raises: + KeyError: If a loss config is missing required keys (e.g., out_dim for DINO). + + Example: + >>> target_losses = {"JEPA": {"head": "identity"}, "DINO": {"out_dim": 256, ...}} + >>> postprocessors = get_target_postprocessing(target_losses, training_cfg) + >>> postprocessors["JEPA"](teacher_output) # Identity transform + """ return_dict = {} for loss_name, conf in target_losses.items(): if loss_name == "iBOT": + # Validate required keys + required_keys = [ + "out_dim", + "center_momentum", + "loss_extra_args", + "teacher_temp", + "teacher_style", + ] + missing = [k for k in required_keys if k not in conf] + if missing: + raise KeyError(f"iBOT loss config missing required keys: {missing}") + return_dict[loss_name] = iBOTPatchTargetProcessing( patch_out_dim=conf["out_dim"], center_momentum=conf["center_momentum"], @@ -90,6 +540,12 @@ def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): teacher_style=conf["teacher_style"], ) elif loss_name == "DINO": + # Validate required keys + required_keys = ["out_dim", "center_momentum", "loss_extra_args", "teacher_style"] + missing = [k for k in required_keys if k not in conf] + if missing: + raise KeyError(f"DINO loss config missing required keys: {missing}") + return_dict[loss_name] = DINOTargetProcessing( out_dim=conf["out_dim"], center_momentum=conf["center_momentum"], @@ -99,6 +555,8 @@ def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): elif loss_name == "JEPA": return_dict[loss_name] = JEPATargetProcessing() else: - # We skip losses that are not handled by the EMATeacher + # Skip losses that are not handled by the EncoderTeacher + logger.debug(f"get_target_postprocessing: Skipping unknown loss type '{loss_name}'") continue + return return_dict diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index e949dc1cc..6fbdd0487 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -21,7 +21,7 @@ from torch.distributed.tensor import DTensor import weathergen.common.config as config -from weathergen.common.config import Config, merge_configs +from weathergen.common.config import Config from weathergen.datasets.multi_stream_data_sampler import MultiStreamDataSampler from weathergen.model.ema import EMAModel from weathergen.model.model_interface import ( @@ -29,22 +29,32 @@ init_model_and_shard, ) from weathergen.model.utils import apply_fct_to_blocks, set_to_eval +from weathergen.train.collapse_monitor import CollapseMonitor from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler +from weathergen.train.optimizer import CompositeOptimizer, create_optimizer +from weathergen.train.target_and_aux_ssl_teacher import EMATeacher from weathergen.train.trainer_base import TrainerBase from weathergen.train.utils import ( + TRAIN, + VAL, + Stage, + cfg_keys_to_filter, extract_batch_metadata, filter_config_by_enabled, + get_active_stage_config, 
get_batch_size_from_config, get_target_idxs_from_cfg, ) from weathergen.utils.distributed import is_root -from weathergen.utils.train_logger import TRAIN, VAL, Stage, TrainLogger, prepare_losses_for_logging +from weathergen.utils.train_logger import TrainLogger, prepare_losses_for_logging from weathergen.utils.utils import get_dtype from weathergen.utils.validation_io import write_output logger = logging.getLogger(__name__) +# cfg_keys_to_filter = ["losses", "model_input", "target_input"] + class Trainer(TrainerBase): def __init__(self, train_log_freq: Config): @@ -74,6 +84,7 @@ def __init__(self, train_log_freq: Config): self.batch_size_per_gpu = -1 self.batch_size_validation_per_gpu = -1 self.batch_size_test_per_gpu = -1 + self.collapse_monitor: CollapseMonitor | None = None def get_batch_size_total(self, batch_size_per_gpu) -> int: """ @@ -99,22 +110,21 @@ def init(self, cf: Config, devices): self.freeze_modules = cf.get("freeze_modules", "") - # keys to filter for enabled/disabled - keys_to_filter = ["losses", "model_input", "target_input"] - # get training config and remove disabled options (e.g. because of overrides) self.training_cfg = cf.get("training_config") - self.training_cfg = filter_config_by_enabled(self.training_cfg, keys_to_filter) + self.training_cfg = filter_config_by_enabled(self.training_cfg, cfg_keys_to_filter) assert len(self.training_cfg.model_input.keys()) != 0, ( "You probably have no loss term enabled" ) # validation and test configs are training configs, updated by specified keys - self.validation_cfg = merge_configs(self.training_cfg, cf.get("validation_config", {})) - self.validation_cfg = filter_config_by_enabled(self.validation_cfg, keys_to_filter) + self.validation_cfg = get_active_stage_config( + self.training_cfg, cf.get("validation_config", {}), cfg_keys_to_filter + ) # test cfg is derived from validation cfg with specified keys overwritten - self.test_cfg = merge_configs(self.validation_cfg, cf.get("test_config", {})) - self.test_cfg = filter_config_by_enabled(self.test_cfg, keys_to_filter) + self.test_cfg = get_active_stage_config( + self.validation_cfg, cf.get("test_config", {}), cfg_keys_to_filter + ) # batch sizes self.batch_size_per_gpu = get_batch_size_from_config(self.training_cfg) @@ -146,6 +156,10 @@ def init(self, cf: Config, devices): self.train_logger = TrainLogger(cf, config.get_path_run(self.cf)) + # Initialize collapse monitor for SSL training + collapse_config = self.training_cfg.get("collapse_monitoring", {}) + self.collapse_monitor = CollapseMonitor(collapse_config, None) # device set later in run() + def get_target_aux_calculators(self, mode_cfg): """ Get target_aux_calculators for given mode_cfg @@ -227,6 +241,9 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): device_type = torch.accelerator.current_accelerator() self.device = torch.device(f"{device_type}:{cf.local_rank}") + # Update collapse monitor device + self.collapse_monitor.device = self.device + # create data loaders self.dataset = MultiStreamDataSampler(cf, self.training_cfg, stage=TRAIN) self.dataset_val = MultiStreamDataSampler(cf, self.validation_cfg, stage=VAL) @@ -289,22 +306,14 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): if not cf.with_ddp: self.model.print_num_parameters() - # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ - # aiming for beta1=0.9 and beta2=0.95 following the MAE paper - # https://arxiv.org/pdf/2111.06377 - kappa = self.get_batch_size_total(self.batch_size_per_gpu) - # 
aiming for beta1 = 0.9 at one node, ie kappa=B=4 - beta1 = max(0.5, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta1)) - # aiming for beta2 = 0.95 at one node, ie B=4 - beta2 = 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2) - eps = self.training_cfg.optimizer.adamw.get("eps", 2e-08) / np.sqrt(kappa) - - self.optimizer = torch.optim.AdamW( - self.model.parameters(), - lr=self.training_cfg.learning_rate_scheduling.lr_start, - weight_decay=self.training_cfg.optimizer.weight_decay, - betas=(beta1, beta2), - eps=eps, + # Create optimizer using factory function + # Supports both standard AdamW and hybrid Muon+AdamW configurations + # See: https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + self.optimizer = create_optimizer( + model=self.model, + optimizer_cfg=self.training_cfg.optimizer, + lr_cfg=self.training_cfg.learning_rate_scheduling, + batch_size_total=self.get_batch_size_total(self.batch_size_per_gpu), ) self.grad_scaler = torch.amp.GradScaler("cuda") @@ -501,9 +510,16 @@ def train(self, mini_epoch): if self.validate_with_ema: self.ema_model.update(self.cf.general.istep * batch_size_total, batch_size_total) + # Compute collapse monitoring metrics + if self.collapse_monitor.should_compute(self.cf.general.istep): + self._compute_collapse_metrics(preds, targets_and_auxs) + self._log_terminal(bidx, mini_epoch, TRAIN) if bidx % self.train_log_freq.metrics == 0: self._log(TRAIN) + # Log collapse metrics + if self.collapse_monitor.should_log(self.cf.general.istep): + self._log_collapse_metrics(TRAIN) # save model checkpoint (with designation _latest) if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0: @@ -618,6 +634,12 @@ def _get_full_model_state_dict(self): def _get_full_optimizer_state_dict(self): is_rank_zero = is_root() + + # Handle CompositeOptimizer (Muon+AdamW) separately + if isinstance(self.optimizer, CompositeOptimizer): + return self._get_full_composite_optimizer_state_dict(is_rank_zero) + + # Standard optimizer (AdamW) handling sharded_sd = self.optimizer.state_dict() sharded_state = sharded_sd["state"] full_state = {} @@ -646,6 +668,50 @@ def _get_full_optimizer_state_dict(self): else: return {} + def _get_full_composite_optimizer_state_dict(self, is_rank_zero: bool): + """ + Get full optimizer state dict for CompositeOptimizer (Muon+AdamW). + + Handles DTensor consolidation for both sub-optimizers. 
+ """ + + def consolidate_optimizer_state(optimizer): + """Consolidate sharded state from a single optimizer.""" + sharded_sd = optimizer.state_dict() + sharded_state = sharded_sd["state"] + full_state = {} + for group_id, sharded_group in sharded_state.items(): + group_state = {} + for attr, sharded_tensor in sharded_group.items(): + if isinstance(sharded_tensor, DTensor): + full_tensor = sharded_tensor.full_tensor() + else: + full_tensor = sharded_tensor + if is_rank_zero: + group_state[attr] = full_tensor.cpu() + else: + del full_tensor + if is_rank_zero: + full_state[group_id] = group_state + else: + del group_state + if is_rank_zero: + return { + "param_groups": sharded_sd["param_groups"], + "state": full_state, + } + return {} + + if is_rank_zero: + return { + "optimizer_type": "composite_muon_adamw", + "muon": consolidate_optimizer_state(self.optimizer.muon_optimizer), + "adamw": consolidate_optimizer_state(self.optimizer.adamw_optimizer), + "muon_lr_multiplier": self.optimizer.muon_lr_multiplier, + } + else: + return {} + def save_model(self, mini_epoch: int, name=None): # Saving at mini_epoch == max_mini_epoch means that we are saving the latest checkpoint. max_mini_epoch = self.training_cfg.num_mini_epochs @@ -775,3 +841,106 @@ def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): logger.info("\n") self.t_start = time.time() + + def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None: + """ + Extract latent tensors from predictions and targets, then compute collapse metrics. + + This method extracts the student and teacher latent representations from the + SSL training outputs and passes them to the collapse monitor. + """ + # Get student latents from predictions (first forecast step) + student_latent = None + teacher_latent = None + prototype_probs = None + ema_beta = None + loss_type = None + + # Find SSL loss type and extract latents + for _loss_name, target_aux in targets_and_auxs.items(): + # Check if this is an EMATeacher-based loss + if hasattr(target_aux, "latent") and target_aux.latent: + # Handle both cases: + # 1. latent is a list[dict] (as per TargetAuxOutput dataclass) + # 2. 
latent is a dict (as set directly by EMATeacher) + if isinstance(target_aux.latent, list): + target_latent_dict = target_aux.latent[0] if target_aux.latent else {} + else: + # EMATeacher sets latent directly as a dict + target_latent_dict = target_aux.latent + + # Determine the SSL loss type (JEPA, DINO, iBOT) + for ssl_type in ["JEPA", "DINO", "iBOT"]: + if ssl_type in target_latent_dict: + loss_type = ssl_type + # Get teacher latent + teacher_latent_data = target_latent_dict[ssl_type] + if isinstance(teacher_latent_data, list) and len(teacher_latent_data) > 0: + teacher_latent = teacher_latent_data[0] + elif isinstance(teacher_latent_data, dict): + # Handle LatentState or dict + teacher_latent = teacher_latent_data.get( + "latent", teacher_latent_data + ) + else: + teacher_latent = teacher_latent_data + break + + # Get student latents from predictions + if preds.latent and len(preds.latent) > 0: + pred_latent_dict = preds.latent[0] + for ssl_type in ["JEPA", "DINO", "iBOT"]: + if ssl_type in pred_latent_dict: + student_latent_data = pred_latent_dict[ssl_type] + if isinstance(student_latent_data, list) and len(student_latent_data) > 0: + student_latent = student_latent_data[0] + elif isinstance(student_latent_data, dict): + student_latent = student_latent_data.get("latent", student_latent_data) + else: + student_latent = student_latent_data + loss_type = ssl_type + break + + # Get EMA beta from target_and_aux_calculators + for _calc_name, calculator in self.target_and_aux_calculators.items(): + if isinstance(calculator, EMATeacher): + batch_size_total = self.get_batch_size_total(self.batch_size_per_gpu) + step = batch_size_total * self.cf.general.istep + ema_beta = calculator.get_current_beta(step) + break + + # Debug logging for tensor extraction + if student_latent is not None: + shape = student_latent.shape if isinstance(student_latent, torch.Tensor) else "N/A" + logger.debug(f"Collapse monitor - student: type={type(student_latent)}, shape={shape}") + else: + logger.debug("Collapse monitor - student_latent is None") + + if teacher_latent is not None: + shape = teacher_latent.shape if isinstance(teacher_latent, torch.Tensor) else "N/A" + logger.debug(f"Collapse monitor - teacher: type={type(teacher_latent)}, shape={shape}") + else: + logger.debug("Collapse monitor - teacher_latent is None") + + # Ensure tensors are properly formatted + if student_latent is not None and isinstance(student_latent, torch.Tensor): + self.collapse_monitor.compute_metrics( + student_latent=student_latent, + teacher_latent=teacher_latent if isinstance(teacher_latent, torch.Tensor) else None, + prototype_probs=prototype_probs, + ema_beta=ema_beta, + loss_type=loss_type, + ) + else: + logger.debug( + f"Collapse monitor - skipping compute_metrics: " + f"student_latent is {'None' if student_latent is None else type(student_latent)}" + ) + + def _log_collapse_metrics(self, stage: Stage) -> None: + """ + Log cached collapse monitoring metrics. 
+ """ + metrics = self.collapse_monitor.get_cached_metrics() + if metrics and is_root(): + self.train_logger.log_metrics(stage, metrics) diff --git a/src/weathergen/train/utils.py b/src/weathergen/train/utils.py index b3ddba5b0..81c8d0ae9 100644 --- a/src/weathergen/train/utils.py +++ b/src/weathergen/train/utils.py @@ -9,11 +9,23 @@ import copy import json +from typing import Literal import torch +from omegaconf import OmegaConf from weathergen.common import config -from weathergen.common.config import Config +from weathergen.common.config import Config, merge_configs + +# Run stages +Stage = Literal["train", "val", "test"] +TRAIN: Stage = "train" +VAL: Stage = "val" +TEST: Stage = "test" + +# keys to filter using enabled: True/False +cfg_keys_to_filter = ["losses", "model_input", "target_input"] + # TODO: remove this definition, it should directly using common. get_run_id = config.get_run_id @@ -149,7 +161,21 @@ def get_target_idxs_from_cfg(cfg, loss_name) -> list[int] | None: return target_idxs -def filter_config_by_enabled(cfg, keys): +def get_active_stage_config( + base_config: dict | OmegaConf, merge_config: dict | OmegaConf, keys_to_filter: list[str] +) -> dict | OmegaConf: + """ + Combine a stage config with its predecessor and filter by enabled: False to obtain the + final config that is used + """ + + result_cfg = merge_configs(base_config, merge_config) + result_cfg = filter_config_by_enabled(result_cfg, keys_to_filter) + + return result_cfg + + +def filter_config_by_enabled(cfg: dict | OmegaConf, keys: list[str]): """ Filtered disabled entries from config """ diff --git a/src/weathergen/utils/cli.py b/src/weathergen/utils/cli.py index bc98ba11d..2bd9fe2a2 100644 --- a/src/weathergen/utils/cli.py +++ b/src/weathergen/utils/cli.py @@ -1,19 +1,65 @@ import argparse +import enum from pathlib import Path import pandas as pd +class Stage(enum.StrEnum): + train = enum.auto() + train_continue = enum.auto() + inference = enum.auto() + + +def get_main_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(allow_abbrev=False) + subparsers = parser.add_subparsers(dest="stage") + + train_parser = subparsers.add_parser( + Stage.train, + help="Train a WeatherGenerator configuration from the ground up.", + ) + _add_train_args(train_parser) + continue_parser = subparsers.add_parser( + Stage.train_continue, + help="Resume training from a pretrained WeatherGenerator configuration.", + ) + _add_continue_args(continue_parser) + inference_parser = subparsers.add_parser( + Stage.inference, + help="Run infernce on a trained WeatherGenerator configuration", + ) + _add_inference_args(inference_parser) + + return parser + + def get_train_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(allow_abbrev=False) - _add_general_arguments(parser) + _add_train_args(parser) return parser def get_continue_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(allow_abbrev=False) + _add_continue_args(parser) + return parser + + +def get_inference_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(allow_abbrev=False) + _add_inference_args(parser) + + return parser + + +def _add_train_args(parser: argparse.ArgumentParser): + _add_general_arguments(parser) + + +def _add_continue_args(parser: argparse.ArgumentParser): _add_general_arguments(parser) _add_model_loading_params(parser) @@ -26,12 +72,8 @@ def get_continue_parser() -> argparse.ArgumentParser: ), ) - return parser - - -def get_inference_parser() -> argparse.ArgumentParser: - parser = 
argparse.ArgumentParser(allow_abbrev=False) +def _add_inference_args(parser: argparse.ArgumentParser): _add_model_loading_params(parser) _add_general_arguments(parser) @@ -64,8 +106,6 @@ def get_inference_parser() -> argparse.ArgumentParser: help="Output streams during inference.", ) - return parser - def _format_date(date: str) -> str: try: diff --git a/src/weathergen/utils/metrics.py b/src/weathergen/utils/metrics.py index aedb48739..22e8745de 100644 --- a/src/weathergen/utils/metrics.py +++ b/src/weathergen/utils/metrics.py @@ -61,4 +61,4 @@ def get_train_metrics_path(base_path: Path, run_id: str) -> Path: if (base_path / run_id / "metrics.json").exists(): return base_path / run_id / "metrics.json" else: - return base_path / run_id / f"{run_id}_train_metrics.json" + return base_path / f"{run_id}_train_metrics.json" diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 35bfafe3e..bd54a564c 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -9,8 +9,10 @@ import argparse import logging +import pdb import subprocess import sys +import traceback from pathlib import Path import matplotlib.pyplot as plt @@ -167,7 +169,7 @@ def get_stream_names(run_id: str, model_path: Path | None = "./model"): List of stream names """ # return col names from training (should be identical to validation) - cf = config.load_run_config(run_id, -1, model_path=model_path) + cf = config.load_run_config(run_id, None, model_path=model_path) return [si["name"].replace(",", "").replace("/", "_").replace(" ", "_") for si in cf.streams] @@ -316,6 +318,49 @@ def plot_utilization( plt.close() +def plot_loss_avg(plot_dir: Path, runs_ids, runs_data, x_scale_log=False): + prop_cycle = plt.rcParams["axes.prop_cycle"] + colors = prop_cycle.by_key()["color"] + ["r", "g", "b", "k", "y", "m"] + + # # legend = plt.legend(legend_str, loc="upper right" if not x_scale_log else "lower left") + # for line in legend.get_lines(): + # line.set(alpha=1.0) + _fig = plt.figure(figsize=(10, 7), dpi=300) + + legend_str = [] + for i_run, (run_id, run_data) in enumerate(zip(runs_ids, runs_data, strict=False)): + x_vals = np.array(run_data.train["num_samples"]) + y_vals = np.array(run_data.train["loss_avg_mean"]) + plt.plot( + x_vals, + y_vals, + color=colors[i_run % len(colors)], + ) + legend_str += [run_id + " : " + runs_ids[run_id][1]] + # ("R" if runs_active[j] else "X") + # + " : " + # run_id + ", " + col + " : " + runs_ids[run_id][1] + # ] + + plt.legend(legend_str) + plt.grid(True, which="both", ls="-") + plt.yscale("log") + # cap at 1.0 in case of divergence of run (through normalization, max should be around 1.0) + # plt.ylim([0.95 * min_val, (None if max_val < 2.0 else min(1.1, 1.025 * max_val))]) + if x_scale_log: + plt.xscale("log") + plt.title("average loss") + plt.ylabel("loss") + plt.xlabel("step") + plt.tight_layout() + rstr = "".join([f"{r}_" for r in runs_ids]) + + plt_fname = plot_dir / f"{rstr}avg.png" + _logger.info(f"Saving avg plot to '{plt_fname}'") + plt.savefig(plt_fname) + plt.close() + + #################################################################################################### def plot_loss_per_stream( modes: list[str], @@ -357,7 +402,7 @@ def plot_loss_per_stream( """ if errs is None: - errs = ["loss_mse"] + errs = ["mse"] modes = [modes] if type(modes) is not list else modes # repeat colors when train and val is plotted simultaneously @@ -688,6 +733,9 @@ def plot_train(args=None): # plot learning rate plot_lr(runs_ids, 
runs_data, runs_active, plot_dir=out_dir) + # plot average loss + plot_loss_avg(out_dir, runs_ids, runs_data) + # # plot performance # plot_utilization(runs_ids, runs_data, runs_active, plot_dir=out_dir) @@ -746,4 +794,9 @@ def plot_train(args=None): if __name__ == "__main__": args = sys.argv[1:] # get CLI args - plot_train(args) + try: + plot_train(args) + except Exception: + extype, value, tb = sys.exc_info() + traceback.print_exc() + pdb.post_mortem(tb) diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 2812134a0..f2313baa4 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -16,14 +16,15 @@ from collections import defaultdict from dataclasses import dataclass from pathlib import Path -from typing import Literal import numpy as np import polars as pl import torch import weathergen.common.config as config -from weathergen.train.utils import flatten_dict + +# from weathergen.train.trainer import cfg_keys_to_filter +from weathergen.train.utils import Stage, cfg_keys_to_filter, flatten_dict, get_active_stage_config from weathergen.utils.distributed import ddp_average from weathergen.utils.metrics import get_train_metrics_path, read_metrics_file @@ -35,13 +36,8 @@ _logger = logging.getLogger(__name__) -Stage = Literal["train", "val"] RunId = str -# All the stages currently implemented: -TRAIN: Stage = "train" -VAL: Stage = "val" - @dataclass class Metrics: @@ -91,7 +87,7 @@ def log_metrics(self, stage: Stage, metrics: dict[str, float]) -> None: # but we can probably do better and rely for example on the logging module. metrics_path = get_train_metrics_path( - base_path=config.get_path_run(self.cf).parent, run_id=self.cf.general.run_id + base_path=config.get_path_run(self.cf), run_id=self.cf.general.run_id ) with open(metrics_path, "ab") as f: s = json.dumps(clean_metrics) + "\n" @@ -131,10 +127,11 @@ def add_logs( ####################################### @staticmethod - def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: + def read(run_id: str, model_path: str = None, mini_epoch: int | None = None) -> Metrics: """ Read data for run_id """ + # Load config from given model_path if provided, otherwise use path from private config if model_path: cf = config.load_run_config(run_id=run_id, mini_epoch=mini_epoch, model_path=model_path) @@ -148,28 +145,15 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: result_dir = result_dir_base / run_id fname_log_train = result_dir / f"{run_id}_train_log.txt" fname_log_val = result_dir / f"{run_id}_val_log.txt" - fname_perf_val = result_dir / f"{run_id}_perf_log.txt" # training # define cols for training - cols_train = ["dtime", "samples", "mse", "lr"] - cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"] - for si in cf.streams: - for lf in cf.loss_fcts: - cols1 += [_key_loss(si["name"], lf[0])] - cols_train += [ - si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0] - ] - with_stddev = [("stats" in lf) for lf in cf.loss_fcts] - if with_stddev: - for si in cf.streams: - cols1 += [_key_stddev(si["name"])] - cols_train += [ - si["name"].replace(",", "").replace("/", "_").replace(" ", "_") - + ", " - + "stddev" - ] + training_cfg = get_active_stage_config(cf.training_config, {}, cfg_keys_to_filter) + cols1, cols_train = get_loss_terms_per_stream(cf.streams, training_cfg) + cols_train += ["dtime", "samples", "mse", "lr"] + cols1 += [_weathergen_timestamp, 
"num_samples", "loss_avg_mean", "learning_rate"] + # read training log data try: with open(fname_log_train, "rb") as f: @@ -211,23 +195,13 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: # validation # define cols for validation + validation_cfg = get_active_stage_config( + training_cfg, cf.get("validation_config", {}), cfg_keys_to_filter + ) + cols2, cols_val = get_loss_terms_per_stream(cf.streams, validation_cfg) cols_val = ["dtime", "samples"] cols2 = [_weathergen_timestamp, "num_samples"] - for si in cf.streams: - for lf in cf.loss_fcts_val: - cols_val += [ - si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0] - ] - cols2 += [_key_loss(si["name"], lf[0])] - with_stddev = [("stats" in lf) for lf in cf.loss_fcts_val] - if with_stddev: - for si in cf.streams: - cols2 += [_key_stddev(si["name"])] - cols_val += [ - si["name"].replace(",", "").replace("/", "_").replace(" ", "_") - + ", " - + "stddev" - ] + # read validation log data try: with open(fname_log_val, "rb") as f: @@ -266,54 +240,7 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: log_val = np.array([]) metrics_val_df = read_metrics(cf, run_id, "val", cols2, result_dir_base) - # performance - # define cols for performance monitoring - cols_perf = ["GPU", "memory"] - # read perf log data - try: - with open(fname_perf_val, "rb") as f: - log_perf = np.loadtxt(f, delimiter=",") - log_perf = log_perf.reshape((log_perf.shape[0] // len(cols_perf), len(cols_perf))) - except ( - TypeError, - AttributeError, - IndexError, - ZeroDivisionError, - ValueError, - ) as e: - _logger.warning( - ( - f"Warning: no validation data loaded for run_id={run_id}", - "Data loading or reshaping failed — " - "possible format, dimension, or logic issue.", - f"Due to specific error: {e}", - ) - ) - except (FileNotFoundError, PermissionError, OSError) as e: - _logger.error( - ( - f"Error: no validation data loaded for run_id={run_id}", - "File system error occurred while handling the log file.", - f"Due to specific error: {e}", - ) - ) - except Exception: - _logger.error( - ( - f"Error: no validation data loaded for run_id={run_id}", - f"Due to exception with trace:\n{traceback.format_exc()}", - ) - ) - log_perf = np.array([]) - metrics_system_df = read_metrics( - cf, - run_id, - None, - [_weathergen_timestamp, _performance_gpu, _performance_memory], - result_dir_base, - ) - - return Metrics(run_id, "train", log_train_df, metrics_val_df, metrics_system_df) + return Metrics(run_id, "train", log_train_df, metrics_val_df, None) def read_metrics( @@ -391,9 +318,27 @@ def clean_name(s: str) -> str: return "".join(c for c in s if c.isalnum() or c == "_") +def get_loss_terms_per_stream(streams, stage_config): + """ + Extract per stream loss terms + """ + cols, cols_stage = [], [] + for si in streams: + for _, loss_config in stage_config.get("losses", {}).items(): + if loss_config.get("type", "LossPhysical") == "LossPhysical": + for lname, _ in loss_config.loss_fcts.items(): + cols += [_key_loss(si["name"], lname)] + cols_stage += [_clean_stream_name(si["name"]) + lname] + return cols, cols_stage + + +def _clean_stream_name(stream_name: str) -> str: + return stream_name.replace(",", "").replace("/", "_").replace(" ", "_") + ", " + + def _key_loss(st_name: str, lf_name: str) -> str: st_name = clean_name(st_name) - return f"stream.{st_name}.loss_{lf_name}.loss_avg" + return f"LossPhysical.{st_name}.{lf_name}.avg" def _key_loss_chn(st_name: str, lf_name: str, ch_name: str) -> str: 
diff --git a/src/weathergen/utils/utils.py b/src/weathergen/utils/utils.py index aee807341..291ab1521 100644 --- a/src/weathergen/utils/utils.py +++ b/src/weathergen/utils/utils.py @@ -10,7 +10,7 @@ import torch -from weathergen.utils.train_logger import TRAIN, Stage +from weathergen.train.utils import TRAIN, Stage def get_dtype(value: str) -> torch.dtype: diff --git a/tests/test_collapse_monitor.py b/tests/test_collapse_monitor.py new file mode 100644 index 000000000..6a6f2ed8c --- /dev/null +++ b/tests/test_collapse_monitor.py @@ -0,0 +1,417 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Unit tests for collapse monitoring metrics.""" + +import pytest +import torch + +from weathergen.train.collapse_monitor import CollapseMonitor + + +@pytest.fixture +def default_config(): + """Default enabled config for collapse monitoring.""" + return { + "enabled": True, + "compute_frequency": 100, + "log_frequency": 100, + "metrics": { + "effective_rank": { + "enabled": True, + "tensor_source": "both", + "sample_size": 2048, + }, + "singular_values": { + "enabled": True, + "tensor_source": "both", + "sample_size": 2048, + }, + "dimension_variance": { + "enabled": True, + "tensor_source": "both", + }, + "prototype_entropy": { + "enabled": True, + }, + "ema_beta": { + "enabled": True, + }, + }, + } + + +@pytest.fixture +def monitor(default_config): + """Create a collapse monitor with default config.""" + device = torch.device("cpu") + return CollapseMonitor(default_config, device) + + +class TestCollapseMonitorInitialization: + """Test CollapseMonitor initialization.""" + + def test_disabled_monitor(self): + """Test that disabled monitor doesn't compute metrics.""" + config = {"enabled": False} + monitor = CollapseMonitor(config, torch.device("cpu")) + assert not monitor.enabled + assert not monitor.should_compute(100) + assert not monitor.should_log(100) + + def test_enabled_monitor(self, default_config): + """Test that enabled monitor computes at correct intervals.""" + monitor = CollapseMonitor(default_config, torch.device("cpu")) + assert monitor.enabled + assert monitor.should_compute(0) + assert monitor.should_compute(100) + assert not monitor.should_compute(50) + + def test_frequency_settings(self): + """Test custom frequency settings.""" + config = { + "enabled": True, + "compute_frequency": 50, + "log_frequency": 200, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + assert monitor.should_compute(50) + assert monitor.should_compute(100) # 100 is a multiple of 50 + assert not monitor.should_compute(75) # 75 is not a multiple of 50 + assert monitor.should_log(200) + assert not monitor.should_log(100) + + +class TestEffectiveRank: + """Test effective rank computation.""" + + def test_full_rank_matrix(self, monitor): + """Full rank random matrix should have effective rank close to min(N, D).""" + torch.manual_seed(42) + # Create a full-rank matrix with orthogonal columns + dim = 64 + num_samples = 128 + z = torch.randn(num_samples, dim) + # Make it more orthogonal via QR decomposition + q, _ = torch.linalg.qr(z.T) + z = q.T # Now z is [dim, dim] with orthogonal rows + z = torch.cat([z, torch.randn(num_samples - dim, 
dim)], dim=0) + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # For a full-rank matrix, effective rank should be significant portion of D + assert eff_rank > dim * 0.3, f"Expected effective rank > {dim * 0.3}, got {eff_rank}" + + def test_low_rank_matrix(self, monitor): + """Low rank matrix should have effective rank close to actual rank.""" + torch.manual_seed(42) + # Create a rank-5 matrix + actual_rank = 5 + num_samples, dim = 128, 64 + u_mat = torch.randn(num_samples, actual_rank) + v_mat = torch.randn(actual_rank, dim) + z = u_mat @ v_mat + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Effective rank should be close to actual rank + assert eff_rank < actual_rank * 2, ( + f"Expected effective rank < {actual_rank * 2}, got {eff_rank}" + ) + assert eff_rank > actual_rank * 0.5, ( + f"Expected effective rank > {actual_rank * 0.5}, got {eff_rank}" + ) + + def test_collapsed_matrix(self, monitor): + """Completely collapsed matrix should have effective rank ~1.""" + num_samples, dim = 128, 64 + # All rows are the same (rank 1) + row = torch.randn(1, dim) + z = row.expand(num_samples, dim).clone() + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Effective rank should be very close to 1 + assert eff_rank < 2, f"Expected effective rank < 2, got {eff_rank}" + + def test_3d_tensor_flattening(self, monitor): + """Test that [B, N, D] tensors are properly flattened.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + z = torch.randn(batch_size, num_patches, dim) + + eff_rank = monitor._compute_effective_rank(z, sample_size=0) + # Should compute without error and return reasonable value + assert 1 <= eff_rank <= dim + + +class TestSingularValues: + """Test singular value spectrum computation.""" + + def test_singular_value_statistics(self, monitor): + """Test that singular value statistics are correctly computed.""" + torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + + sv_metrics = monitor._compute_singular_values(z, sample_size=0) + + # Check that we got min, max, mean statistics + assert "sv_min" in sv_metrics + assert "sv_max" in sv_metrics + assert "sv_mean" in sv_metrics + assert "sv_concentration" in sv_metrics + + # Max should be >= mean >= min + assert sv_metrics["sv_max"] >= sv_metrics["sv_mean"] + assert sv_metrics["sv_mean"] >= sv_metrics["sv_min"] + + def test_concentration_ratio(self, monitor): + """Test singular value concentration ratio.""" + torch.manual_seed(42) + # Create a rank-1 matrix where first SV dominates + num_samples, dim = 128, 64 + # Use outer product to create a truly rank-1 dominated matrix + u_vec = torch.randn(num_samples, 1) + v_vec = torch.randn(1, dim) + z = u_vec @ v_vec * 10 + torch.randn(num_samples, dim) * 0.01 # Strong rank-1 component + + sv_metrics = monitor._compute_singular_values(z, sample_size=0) + + # Concentration should be high when one SV dominates + assert "sv_concentration" in sv_metrics + assert sv_metrics["sv_concentration"] > 0.8 # First SV dominates strongly + + # Max should be much larger than min for rank-1 dominated matrix + assert sv_metrics["sv_max"] > sv_metrics["sv_min"] * 10 + + def test_uniform_singular_values(self, monitor): + """Test with random matrix (spread singular values).""" + torch.manual_seed(42) + # Random matrix will have spread singular values + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + + sv_metrics = monitor._compute_singular_values(z, sample_size=0) + + # Concentration should be 
relatively low for random matrix + assert sv_metrics["sv_concentration"] < 0.2 + + # All statistics should be positive + assert sv_metrics["sv_min"] > 0 + assert sv_metrics["sv_max"] > 0 + assert sv_metrics["sv_mean"] > 0 + + +class TestDimensionVariance: + """Test per-dimension variance computation.""" + + def test_random_matrix_balanced_variance(self, monitor): + """Random matrix should have balanced variance across dimensions.""" + torch.manual_seed(42) + num_samples, dim = 1024, 64 + z = torch.randn(num_samples, dim) + + var_metrics = monitor._compute_dimension_variance(z) + + # All variances should be close to 1 for standard normal + assert abs(var_metrics["var_mean"] - 1.0) < 0.2 + # Variance ratio should be small for random matrix + var_ratio = var_metrics["var_max"] / (var_metrics["var_min"] + 1e-8) + assert var_ratio < 5 # Balanced dimensions + + def test_dead_dimensions(self, monitor): + """Test detection of dead (zero-variance) dimensions.""" + torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + # Kill some dimensions (set to constant) + z[:, :10] = 0.5 + + var_metrics = monitor._compute_dimension_variance(z) + + # Minimum variance should be very close to 0 (dead dimensions) + assert var_metrics["var_min"] < 1e-6 + + def test_imbalanced_dimensions(self, monitor): + """Test with highly imbalanced dimension variances.""" + torch.manual_seed(42) + num_samples, dim = 128, 64 + z = torch.randn(num_samples, dim) + # Scale some dimensions much more than others + z[:, 0] *= 100 + z[:, 1:10] *= 0.01 + + var_metrics = monitor._compute_dimension_variance(z) + + # Large variance ratio indicates imbalance + var_ratio = var_metrics["var_max"] / (var_metrics["var_min"] + 1e-8) + assert var_ratio > 1000 + + +class TestPrototypeEntropy: + """Test DINO prototype entropy computation.""" + + def test_uniform_prototype_distribution(self, monitor): + """Uniform prototype distribution should have entropy ~1.""" + batch_size, num_prototypes = 128, 64 + # Uniform distribution + probs = torch.ones(batch_size, num_prototypes) / num_prototypes + + entropy = monitor._compute_prototype_entropy(probs) + + # Normalized entropy should be close to 1 + assert abs(entropy - 1.0) < 0.01 + + def test_single_prototype_collapse(self, monitor): + """Collapse to single prototype should have entropy ~0.""" + batch_size, num_prototypes = 128, 64 + # All mass on first prototype + probs = torch.zeros(batch_size, num_prototypes) + probs[:, 0] = 1.0 + + entropy = monitor._compute_prototype_entropy(probs) + + # Normalized entropy should be close to 0 + assert entropy < 0.01 + + def test_partial_collapse(self, monitor): + """Partial collapse should have intermediate entropy.""" + batch_size, num_prototypes = 128, 64 + # Only 4 prototypes used uniformly (much stronger collapse) + probs = torch.zeros(batch_size, num_prototypes) + probs[:, :4] = 0.25 # Only 4 out of 64 prototypes + + entropy = monitor._compute_prototype_entropy(probs) + + # Entropy should be between 0 and 1 (log(4)/log(64) ≈ 0.33) + assert 0.2 < entropy < 0.5 + + +class TestMetricsCaching: + """Test metrics caching and averaging.""" + + def test_cache_accumulation(self, monitor): + """Test that metrics are properly cached.""" + torch.manual_seed(42) + z1 = torch.randn(64, 32) + z2 = torch.randn(64, 32) + + # Compute metrics twice + monitor.compute_metrics(student_latent=z1) + monitor.compute_metrics(student_latent=z2) + + # Cache should contain averaged values + cached = monitor.get_cached_metrics() + assert 
"collapse.student.effective_rank" in cached + + def test_cache_clear(self, monitor): + """Test that cache is cleared after get_cached_metrics.""" + torch.manual_seed(42) + z = torch.randn(64, 32) + + monitor.compute_metrics(student_latent=z) + _ = monitor.get_cached_metrics() + + # Second call should return empty + cached = monitor.get_cached_metrics() + assert len(cached) == 0 + + +class TestIntegration: + """Integration tests with both student and teacher tensors.""" + + def test_full_metrics_computation(self, monitor): + """Test computing all metrics with both student and teacher.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + student = torch.randn(batch_size, num_patches, dim) + teacher = torch.randn(batch_size, num_patches, dim) + + metrics = monitor.compute_metrics( + student_latent=student, + teacher_latent=teacher, + ema_beta=0.999, + loss_type="JEPA", + ) + + # Check that both student and teacher metrics are computed + assert "collapse.student.effective_rank" in metrics + assert "collapse.teacher.effective_rank" in metrics + assert "collapse.student.var_min" in metrics + assert "collapse.teacher.var_min" in metrics + assert "collapse.ema_beta" in metrics + assert metrics["collapse.ema_beta"] == 0.999 + + def test_dino_prototype_entropy(self, monitor): + """Test DINO prototype entropy computation.""" + torch.manual_seed(42) + batch_size, num_patches, dim = 4, 32, 64 + num_prototypes = 128 + student = torch.randn(batch_size, num_patches, dim) + probs = torch.softmax(torch.randn(batch_size, num_prototypes), dim=-1) + + metrics = monitor.compute_metrics( + student_latent=student, + prototype_probs=probs, + loss_type="DINO", + ) + + assert "collapse.dino.prototype_entropy" in metrics + assert 0 <= metrics["collapse.dino.prototype_entropy"] <= 1 + + def test_disabled_metrics(self): + """Test that disabled metrics are not computed.""" + config = { + "enabled": True, + "compute_frequency": 1, + "log_frequency": 1, + "metrics": { + "effective_rank": {"enabled": False}, + "singular_values": {"enabled": False}, + "dimension_variance": {"enabled": True, "tensor_source": "student"}, + "prototype_entropy": {"enabled": False}, + "ema_beta": {"enabled": False}, + }, + } + monitor = CollapseMonitor(config, torch.device("cpu")) + + torch.manual_seed(42) + z = torch.randn(64, 32) + metrics = monitor.compute_metrics(student_latent=z) + + # Only dimension variance should be computed + assert "collapse.student.var_min" in metrics + assert "collapse.student.effective_rank" not in metrics + assert "collapse.student.sv_max" not in metrics + + +class TestSampling: + """Test row sampling for SVD computations.""" + + def test_sampling_reduces_computation(self, monitor): + """Test that sampling works for large tensors.""" + torch.manual_seed(42) + num_samples, dim = 10000, 64 + z = torch.randn(num_samples, dim) + + # With sampling + eff_rank_sampled = monitor._compute_effective_rank(z, sample_size=1024) + # Without sampling + eff_rank_full = monitor._compute_effective_rank(z, sample_size=0) + + # Results should be in same ballpark + assert abs(eff_rank_sampled - eff_rank_full) < eff_rank_full * 0.3 + + def test_no_sampling_when_small(self, monitor): + """Test that small tensors aren't sampled.""" + torch.manual_seed(42) + num_samples, dim = 100, 64 + z = torch.randn(num_samples, dim) + + # Sample size larger than N + sampled = monitor._sample_rows(z, sample_size=1024) + assert sampled.shape[0] == num_samples # No sampling occurred diff --git a/tests/test_config.py 
b/tests/test_config.py index c5fa2e5f3..e04390341 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,8 +9,7 @@ TEST_RUN_ID = "test123" SECRET_COMPONENT = "53CR3T" DUMMY_PRIVATE_CONF = { - "data_path_anemoi": "/path/to/anmoi/data", - "data_path_obs": "/path/to/observation/data", + "data_paths": ["/path/to/anmoi/data", "/path/to/observation/data"], "secrets": { "my_big_secret": { "my_secret_id": f"{SECRET_COMPONENT}01234", diff --git a/tests/test_encoder_teacher.py b/tests/test_encoder_teacher.py new file mode 100644 index 000000000..bb53f234e --- /dev/null +++ b/tests/test_encoder_teacher.py @@ -0,0 +1,671 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Tests for EncoderTeacher class hierarchy (EMATeacher and FrozenTeacher).""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn as nn + +# Mock flash_attn before importing weathergen modules +sys.modules["flash_attn"] = MagicMock() + +from weathergen.train.target_and_aux_module_base import TargetAuxOutput # noqa: E402 + + +# ============================================================================= +# Fixtures for mock objects +# ============================================================================= + + +class MockLatentState: + """Mock latent state that get_latent_prediction returns.""" + + def __init__(self, data: dict): + self._data = data + + def __getitem__(self, key): + return self._data[key] + + +class MockModelOutput: + """Mock model output with get_latent_prediction method.""" + + def __init__(self, latent_data: dict): + self._latent_data = latent_data + + def get_latent_prediction(self, idx: int): + return self._latent_data + + +class MockSample: + """Mock sample with meta_info.""" + + def __init__(self): + self.meta_info = {"key": "value"} + + +class MockBatch: + """Mock batch for testing compute().""" + + def __init__(self, num_samples: int = 2): + self._samples = [MockSample() for _ in range(num_samples)] + + def get_samples(self): + return self._samples + + def get_output_len(self): + return 1 + + def get_output_idxs(self): + return [0] + + +class MockEMAModel: + """Mock EMA model for testing EMATeacher.""" + + def __init__(self, model: nn.Module): + self.model = model + self.ema_model = model + self.is_model_sharded = False + self._reset_called = False + self._update_called = False + self._update_args = None + + def reset(self): + self._reset_called = True + # Copy weights from model to ema_model (simulating real behavior) + with torch.no_grad(): + for p_ema, p_model in zip( + self.ema_model.parameters(), self.model.parameters() + ): + p_ema.copy_(p_model) + + def update(self, istep: int, batch_size: int): + self._update_called = True + self._update_args = (istep, batch_size) + # Simulate EMA update by slightly modifying weights + with torch.no_grad(): + for p in self.ema_model.parameters(): + p.mul_(0.999).add_(torch.randn_like(p) * 0.001) + + def forward_eval(self, model_params, batch): + return self.ema_model(model_params, batch) + + +@pytest.fixture +def simple_model(): + """Create a simple model for testing.""" + model = nn.Sequential(nn.Linear(10, 
10), nn.ReLU(), nn.Linear(10, 5)) + return model + + +@pytest.fixture +def mock_training_cfg(): + """Create mock training config with JEPA loss.""" + from omegaconf import OmegaConf + + cfg = OmegaConf.create( + { + "losses": { + "ssl_loss": { + "type": "LossLatentSSLStudentTeacher", + "loss_fcts": {"JEPA": {"head": "identity", "out_dim": 256}}, + } + } + } + ) + return cfg + + +@pytest.fixture +def mock_ema_model(simple_model): + """Create mock EMA model wrapping simple_model.""" + return MockEMAModel(simple_model) + + +@pytest.fixture +def model_with_latent_heads(): + """Create a model with latent_heads attribute for FrozenTeacher testing.""" + model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5)) + # Add latent_heads attribute to mimic real model structure + model.latent_heads = nn.ModuleDict({"JEPA": nn.Identity()}) + return model + + +@pytest.fixture +def model_without_latent_heads(): + """Create a model WITHOUT latent_heads (like a forecasting-only model).""" + model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5)) + # No latent_heads - simulates a model trained without SSL + return model + + +# ============================================================================= +# Interface Tests - Both EMATeacher and FrozenTeacher must pass these +# ============================================================================= + + +class TestEncoderTeacherInterface: + """Tests for the shared interface of EncoderTeacher subclasses.""" + + def test_ema_teacher_has_required_methods(self): + """Verify EMATeacher has all required interface methods.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + required_methods = [ + "reset", + "update_state_pre_backward", + "update_state_post_opt_step", + "compute", + "to_device", + ] + for method in required_methods: + assert hasattr(EMATeacher, method), f"EMATeacher missing method: {method}" + assert callable( + getattr(EMATeacher, method) + ), f"EMATeacher.{method} is not callable" + + def test_frozen_teacher_has_required_methods(self): + """Verify FrozenTeacher has all required interface methods.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + required_methods = [ + "reset", + "update_state_pre_backward", + "update_state_post_opt_step", + "compute", + "to_device", + ] + for method in required_methods: + assert hasattr( + FrozenTeacher, method + ), f"FrozenTeacher missing method: {method}" + assert callable( + getattr(FrozenTeacher, method) + ), f"FrozenTeacher.{method} is not callable" + + def test_ema_teacher_update_state_pre_backward_is_noop( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """Verify update_state_pre_backward returns None (no-op).""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + teacher = EMATeacher( + simple_model, mock_ema_model, batch_size=8, training_cfg=mock_training_cfg + ) + result = teacher.update_state_pre_backward( + istep=0, batch=MockBatch(), model=simple_model + ) + assert result is None + + def test_frozen_teacher_update_state_pre_backward_is_noop( + self, simple_model, model_with_latent_heads + ): + """Verify FrozenTeacher.update_state_pre_backward returns None (no-op).""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) + result = teacher.update_state_pre_backward( + istep=0, batch=MockBatch(), model=simple_model + ) + assert result is None + + def test_ema_teacher_to_device_moves_postprocessors( + 
self, simple_model, mock_ema_model, mock_training_cfg + ): + """Verify to_device moves postprocessors to specified device.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + teacher = EMATeacher( + simple_model, mock_ema_model, batch_size=8, training_cfg=mock_training_cfg + ) + + # Track if .to() was called on postprocessors + for name, module in teacher.postprocess_targets.items(): + module.to = MagicMock(return_value=module) + + teacher.to_device("cpu") + + for name, module in teacher.postprocess_targets.items(): + module.to.assert_called_once_with("cpu") + + def test_frozen_teacher_to_device_moves_postprocessors( + self, model_with_latent_heads + ): + """Verify FrozenTeacher.to_device moves postprocessors.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) + + for name, module in teacher.postprocess_targets.items(): + module.to = MagicMock(return_value=module) + + teacher.to_device("cpu") + + for name, module in teacher.postprocess_targets.items(): + module.to.assert_called_once_with("cpu") + + +# ============================================================================= +# EMATeacher-specific Tests +# ============================================================================= + + +class TestEMATeacher: + """Tests specific to EMATeacher behavior.""" + + def test_ema_reset_calls_ema_model_reset( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """After reset, EMA model's reset method should be called.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + teacher = EMATeacher( + simple_model, mock_ema_model, batch_size=8, training_cfg=mock_training_cfg + ) + + # Reset is called in __init__, so reset the flag first + mock_ema_model._reset_called = False + + teacher.reset() + assert mock_ema_model._reset_called + + def test_ema_reset_can_update_batch_size( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """Reset can optionally update batch size.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + teacher = EMATeacher( + simple_model, mock_ema_model, batch_size=8, training_cfg=mock_training_cfg + ) + assert teacher.batch_size == 8 + + teacher.reset(batch_size=16) + assert teacher.batch_size == 16 + + def test_ema_update_post_opt_step_calls_ema_update( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """update_state_post_opt_step should call ema_model.update().""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + teacher = EMATeacher( + simple_model, mock_ema_model, batch_size=8, training_cfg=mock_training_cfg + ) + + teacher.update_state_post_opt_step( + istep=10, batch=MockBatch(), model=simple_model + ) + + assert mock_ema_model._update_called + assert mock_ema_model._update_args == (10, 8) + + +# ============================================================================= +# FrozenTeacher-specific Tests +# ============================================================================= + + +class TestFrozenTeacher: + """Tests specific to FrozenTeacher behavior.""" + + def test_frozen_teacher_init_freezes_parameters(self, model_with_latent_heads): + """FrozenTeacher should freeze all model parameters on init.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Verify model starts with requires_grad=True + assert all(p.requires_grad for p in model_with_latent_heads.parameters()) + + teacher = FrozenTeacher(model_with_latent_heads, 
training_cfg=None) + + # All parameters should be frozen + assert all(not p.requires_grad for p in teacher.teacher_model.parameters()) + + def test_frozen_teacher_init_sets_eval_mode(self, model_with_latent_heads): + """FrozenTeacher should set model to eval mode.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + model_with_latent_heads.train() + assert model_with_latent_heads.training + + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) + + assert not teacher.teacher_model.training + + def test_frozen_reset_is_noop(self, model_with_latent_heads): + """FrozenTeacher.reset() should not change weights.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) + + # Get weights before reset + weights_before = { + k: v.clone() for k, v in teacher.teacher_model.state_dict().items() + } + + teacher.reset() + + # Weights should be unchanged + weights_after = teacher.teacher_model.state_dict() + for key in weights_before: + assert torch.equal(weights_before[key], weights_after[key]) + + def test_frozen_update_is_noop(self, model_with_latent_heads): + """FrozenTeacher.update_state_post_opt_step() should not change weights.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) + + # Get weights before update + weights_before = { + k: v.clone() for k, v in teacher.teacher_model.state_dict().items() + } + + teacher.update_state_post_opt_step( + istep=10, batch=MockBatch(), model=MagicMock() + ) + + # Weights should be unchanged + weights_after = teacher.teacher_model.state_dict() + for key in weights_before: + assert torch.equal(weights_before[key], weights_after[key]) + + def test_frozen_weights_require_no_grad(self, model_with_latent_heads): + """All FrozenTeacher parameters should have requires_grad=False.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=None) + + for name, param in teacher.teacher_model.named_parameters(): + assert not param.requires_grad, f"Parameter {name} should have requires_grad=False" + + def test_frozen_model_in_eval_mode(self): + """FrozenTeacher model should always be in eval mode.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + model = nn.Sequential( + nn.Linear(10, 10), nn.BatchNorm1d(10), nn.Linear(10, 5) + ) + # Add latent_heads to model + model.latent_heads = nn.ModuleDict({"JEPA": nn.Identity()}) + model.train() # Start in train mode + + teacher = FrozenTeacher(model, training_cfg=None) + + # Model should be in eval mode + assert not teacher.teacher_model.training + # All submodules should be in eval mode + for module in teacher.teacher_model.modules(): + assert not module.training + + def test_frozen_teacher_adds_identity_heads_when_missing( + self, model_without_latent_heads, mock_training_cfg + ): + """FrozenTeacher should add identity heads if teacher lacks them.""" + from weathergen.model.engines import LatentPredictionHeadIdentity + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Model has no latent_heads + assert not hasattr(model_without_latent_heads, "latent_heads") + + teacher = FrozenTeacher(model_without_latent_heads, training_cfg=mock_training_cfg) + + # latent_heads should now exist with JEPA + assert hasattr(teacher.teacher_model, "latent_heads") + assert "JEPA" in 
teacher.teacher_model.latent_heads + assert isinstance( + teacher.teacher_model.latent_heads["JEPA"], LatentPredictionHeadIdentity + ) + + def test_frozen_teacher_uses_training_cfg_for_heads(self, model_without_latent_heads): + """FrozenTeacher should use training_cfg to determine which heads to add.""" + from omegaconf import OmegaConf + + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Config with both JEPA and DINO losses + cfg = OmegaConf.create( + { + "losses": { + "ssl_loss": { + "type": "LossLatentSSLStudentTeacher", + "loss_fcts": { + "JEPA": {"head": "identity"}, + "DINO": {"head": "mlp", "out_dim": 256}, + }, + } + } + } + ) + + teacher = FrozenTeacher(model_without_latent_heads, training_cfg=cfg) + + # Both heads should be added + assert "JEPA" in teacher.teacher_model.latent_heads + assert "DINO" in teacher.teacher_model.latent_heads + # Postprocessing should exist for both + assert "JEPA" in teacher.postprocess_targets + assert "DINO" in teacher.postprocess_targets + + def test_frozen_teacher_defaults_to_jepa_without_config(self, model_without_latent_heads): + """FrozenTeacher should default to JEPA head when no config provided.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + teacher = FrozenTeacher(model_without_latent_heads, training_cfg=None) + + # Should default to JEPA + assert "JEPA" in teacher.teacher_model.latent_heads + assert "JEPA" in teacher.postprocess_targets + + def test_frozen_teacher_preserves_existing_heads(self, model_with_latent_heads, mock_training_cfg): + """FrozenTeacher should not overwrite existing latent heads.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Get reference to original head + original_head = model_with_latent_heads.latent_heads["JEPA"] + + teacher = FrozenTeacher(model_with_latent_heads, training_cfg=mock_training_cfg) + + # Original head should be preserved (same object) + assert teacher.teacher_model.latent_heads["JEPA"] is original_head + + def test_frozen_teacher_all_postprocessing_is_identity(self, model_without_latent_heads): + """All FrozenTeacher postprocessing should use identity (JEPATargetProcessing).""" + from omegaconf import OmegaConf + + from weathergen.model.ssl_target_processing import JEPATargetProcessing + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Config with multiple SSL losses + cfg = OmegaConf.create( + { + "losses": { + "ssl_loss": { + "type": "LossLatentSSLStudentTeacher", + "loss_fcts": { + "JEPA": {"head": "identity"}, + "DINO": {"head": "mlp"}, + "iBOT": {"head": "mlp"}, + }, + } + } + } + ) + + teacher = FrozenTeacher(model_without_latent_heads, training_cfg=cfg) + + # All postprocessors should be JEPATargetProcessing (identity) + for name, processor in teacher.postprocess_targets.items(): + assert isinstance(processor, JEPATargetProcessing), ( + f"Postprocessor for {name} should be JEPATargetProcessing" + ) + + +# ============================================================================= +# EncoderTeacher Base Class Tests +# ============================================================================= + + +class TestEncoderTeacherBaseClass: + """Tests for EncoderTeacher base class functionality.""" + + def test_encoder_teacher_exists(self): + """Verify EncoderTeacher base class exists.""" + from weathergen.train.target_and_aux_ssl_teacher import EncoderTeacher + + assert EncoderTeacher is not None + + def test_ema_teacher_inherits_from_encoder_teacher(self): + """Verify EMATeacher 
inherits from EncoderTeacher.""" + from weathergen.train.target_and_aux_ssl_teacher import ( + EMATeacher, + EncoderTeacher, + ) + + assert issubclass(EMATeacher, EncoderTeacher) + + def test_frozen_teacher_inherits_from_encoder_teacher(self): + """Verify FrozenTeacher inherits from EncoderTeacher.""" + from weathergen.train.target_and_aux_ssl_teacher import ( + EncoderTeacher, + FrozenTeacher, + ) + + assert issubclass(FrozenTeacher, EncoderTeacher) + + def test_encoder_teacher_has_forward_teacher_method(self): + """Verify EncoderTeacher has _forward_teacher method.""" + from weathergen.train.target_and_aux_ssl_teacher import EncoderTeacher + + assert hasattr(EncoderTeacher, "_forward_teacher") + + +# ============================================================================= +# Validation and Error Handling Tests +# ============================================================================= + + +class TestValidationAndErrorHandling: + """Tests for input validation and error handling.""" + + def test_ema_teacher_rejects_zero_batch_size( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """EMATeacher should reject batch_size <= 0.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + with pytest.raises(ValueError, match="batch_size must be positive"): + EMATeacher( + simple_model, mock_ema_model, batch_size=0, training_cfg=mock_training_cfg + ) + + def test_ema_teacher_rejects_negative_batch_size( + self, simple_model, mock_ema_model, mock_training_cfg + ): + """EMATeacher should reject negative batch_size.""" + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + with pytest.raises(ValueError, match="batch_size must be positive"): + EMATeacher( + simple_model, mock_ema_model, batch_size=-5, training_cfg=mock_training_cfg + ) + + def test_encoder_teacher_rejects_config_without_ssl_losses(self, simple_model): + """EncoderTeacher should reject config with no LossLatentSSLStudentTeacher.""" + from omegaconf import OmegaConf + + from weathergen.train.target_and_aux_ssl_teacher import EMATeacher + + # Config with only physical loss, no SSL + cfg = OmegaConf.create( + { + "losses": { + "physical_loss": { + "type": "LossPhysical", + "weight": 1.0, + } + } + } + ) + + mock_ema = MagicMock() + mock_ema.reset = MagicMock() + mock_ema.update = MagicMock() + mock_ema.forward_eval = MagicMock() + + with pytest.raises(ValueError, match="LossLatentSSLStudentTeacher"): + EMATeacher(simple_model, mock_ema, batch_size=8, training_cfg=cfg) + + def test_frozen_teacher_handles_malformed_config_gracefully(self, model_without_latent_heads): + """FrozenTeacher should handle config without 'losses' attribute.""" + from omegaconf import OmegaConf + + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + # Config without 'losses' key + cfg = OmegaConf.create({"some_other_key": "value"}) + + # Should not raise, should default to JEPA + teacher = FrozenTeacher(model_without_latent_heads, training_cfg=cfg) + assert "JEPA" in teacher.postprocess_targets + + def test_frozen_teacher_from_pretrained_rejects_none_run_id(self): + """from_pretrained should reject None teacher_run_id with helpful message.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + with pytest.raises(ValueError, match="teacher_run_id"): + FrozenTeacher.from_pretrained( + cf=MagicMock(get=lambda k: "/some/path", training_config=None), + dataset=MagicMock(), + device="cpu", + params={}, # Missing teacher_run_id + ) + + def 
test_frozen_teacher_from_pretrained_rejects_empty_run_id(self): + """from_pretrained should reject empty string teacher_run_id.""" + from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher + + with pytest.raises(ValueError, match="non-empty string"): + FrozenTeacher.from_pretrained( + cf=MagicMock(get=lambda k: "/some/path", training_config=None), + dataset=MagicMock(), + device="cpu", + params={"teacher_run_id": ""}, + ) + + def test_get_target_postprocessing_validates_dino_config(self): + """get_target_postprocessing should raise KeyError for missing DINO config.""" + from omegaconf import OmegaConf + + from weathergen.train.target_and_aux_ssl_teacher import get_target_postprocessing + + # DINO config missing required keys + incomplete_config = OmegaConf.create({"DINO": {"out_dim": 256}}) # Missing other keys + + with pytest.raises(KeyError, match="DINO loss config missing required keys"): + get_target_postprocessing(incomplete_config, training_cfg=None) + + def test_get_target_postprocessing_validates_ibot_config(self): + """get_target_postprocessing should raise KeyError for missing iBOT config.""" + from omegaconf import OmegaConf + + from weathergen.train.target_and_aux_ssl_teacher import get_target_postprocessing + + # iBOT config missing required keys + incomplete_config = OmegaConf.create({"iBOT": {"out_dim": 256}}) # Missing other keys + + with pytest.raises(KeyError, match="iBOT loss config missing required keys"): + get_target_postprocessing(incomplete_config, training_cfg=None) diff --git a/tests/test_layer_scale.py b/tests/test_layer_scale.py new file mode 100644 index 000000000..a62002ac1 --- /dev/null +++ b/tests/test_layer_scale.py @@ -0,0 +1,350 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
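Note for readers skimming the new test file below: it pins down the reference behaviour of LayerScale and StochasticDepth. A minimal sketch consistent with what the tests assert (illustrative only, assuming the standard DeiT/ConvNeXt-style formulation; not the actual implementation in weathergen.model.layers):

import torch
import torch.nn as nn

class LayerScaleSketch(nn.Module):
    """Per-channel scaling of a residual branch; gamma starts at init_value (0.0 == ReZero)."""
    def __init__(self, dim: int, init_value: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.full((dim,), init_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma  # broadcasts over the trailing channel dimension

class StochasticDepthSketch(nn.Module):
    """Per-sample drop-path: identity in eval mode, drops whole samples with prob drop_prob in train."""
    def __init__(self, drop_prob: float):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1.0 - self.drop_prob
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # one Bernoulli draw per sample
        mask = torch.empty(mask_shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
        return x * mask / keep_prob  # rescale kept samples so the expected value is preserved

This is exactly the contract exercised below: eval mode is the identity, kept samples are scaled by 1/(1 - drop_prob) so the expectation is preserved, and with layer_scale_init=0.0 a residual MLP initially reduces to the identity.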
+ +"""Unit tests for LayerScale and StochasticDepth modules.""" + +import pytest +import torch + +from weathergen.model.layers import LayerScale, MLP, StochasticDepth + + +class TestLayerScale: + """Tests for the LayerScale module.""" + + def test_init_value(self): + """Test that gamma is initialized to the specified value.""" + dim = 64 + init_value = 1e-5 + layer_scale = LayerScale(dim, init_value) + + assert layer_scale.gamma.shape == (dim,) + assert torch.allclose(layer_scale.gamma, torch.full((dim,), init_value)) + + def test_init_value_rezero(self): + """Test ReZero initialization (init_value=0).""" + dim = 64 + layer_scale = LayerScale(dim, init_value=0.0) + + assert torch.allclose(layer_scale.gamma, torch.zeros(dim)) + + def test_forward_scaling(self): + """Test that forward applies per-channel scaling.""" + dim = 64 + batch_size = 8 + seq_len = 16 + init_value = 0.5 + + layer_scale = LayerScale(dim, init_value) + x = torch.randn(batch_size, seq_len, dim) + + out = layer_scale(x) + + expected = x * init_value + assert torch.allclose(out, expected) + + def test_forward_with_learned_gamma(self): + """Test forward with modified gamma values.""" + dim = 64 + layer_scale = LayerScale(dim, init_value=1.0) + + # Modify gamma + with torch.no_grad(): + layer_scale.gamma.fill_(2.0) + + x = torch.randn(8, 16, dim) + out = layer_scale(x) + + expected = x * 2.0 + assert torch.allclose(out, expected) + + def test_gradient_flow(self): + """Test that gradients flow through LayerScale.""" + dim = 64 + layer_scale = LayerScale(dim, init_value=1e-5) + x = torch.randn(8, 16, dim, requires_grad=True) + + out = layer_scale(x) + loss = out.sum() + loss.backward() + + assert x.grad is not None + assert layer_scale.gamma.grad is not None + + def test_output_shape(self): + """Test that output shape matches input shape.""" + dim = 64 + layer_scale = LayerScale(dim, init_value=1e-5) + + for shape in [(8, dim), (8, 16, dim), (8, 16, 32, dim)]: + x = torch.randn(*shape) + out = layer_scale(x) + assert out.shape == x.shape + + +class TestStochasticDepth: + """Tests for the StochasticDepth module.""" + + def test_init(self): + """Test initialization with drop probability.""" + drop_prob = 0.1 + sd = StochasticDepth(drop_prob) + assert sd.drop_prob == drop_prob + + def test_eval_mode_no_drop(self): + """Test that eval mode never drops (identity).""" + drop_prob = 0.9 # High drop prob + sd = StochasticDepth(drop_prob) + sd.eval() + + x = torch.randn(8, 16, 64) + out = sd(x) + + assert torch.equal(out, x) + + def test_train_mode_zero_prob(self): + """Test that zero drop probability is identity in train mode.""" + sd = StochasticDepth(drop_prob=0.0) + sd.train() + + x = torch.randn(8, 16, 64) + out = sd(x) + + assert torch.equal(out, x) + + def test_train_mode_high_prob(self): + """Test that very high drop probability drops most samples in train mode.""" + sd = StochasticDepth(drop_prob=0.99) + sd.train() + + torch.manual_seed(42) + x = torch.ones(100, 16, 64) + out = sd(x) + + # With 99% drop, most samples should be zero + zero_samples = (out.sum(dim=(1, 2)) == 0).sum().item() + assert zero_samples > 90 # At least 90 out of 100 should be dropped + + def test_expected_value_preservation(self): + """Test that expected value is preserved during training.""" + drop_prob = 0.3 + sd = StochasticDepth(drop_prob) + sd.train() + + torch.manual_seed(42) + x = torch.ones(1000, 16, 64) + + # Run many times to average + outputs = [] + for _ in range(1000): + outputs.append(sd(x).mean().item()) + + mean_output = sum(outputs) / 
len(outputs) + # Expected value should be approximately 1.0 (the input value) + assert abs(mean_output - 1.0) < 0.1 # Allow 10% tolerance + + def test_per_sample_dropping(self): + """Test that dropping is per-sample in batch dimension.""" + drop_prob = 0.5 + sd = StochasticDepth(drop_prob) + sd.train() + + torch.manual_seed(42) + batch_size = 100 + x = torch.ones(batch_size, 16, 64) + + out = sd(x) + + # Check that samples are either scaled or zero + sample_sums = out.sum(dim=(1, 2)) + expected_sum_scaled = 16 * 64 / (1 - drop_prob) + + for s in sample_sums: + # Each sample should be either 0 or scaled + assert s.item() == 0.0 or abs(s.item() - expected_sum_scaled) < 1e-4 + + def test_gradient_flow(self): + """Test that gradients flow through StochasticDepth.""" + sd = StochasticDepth(drop_prob=0.5) + sd.train() + + torch.manual_seed(42) # Ensure some samples are kept + x = torch.randn(8, 16, 64, requires_grad=True) + + out = sd(x) + loss = out.sum() + loss.backward() + + # Gradient should exist for kept samples + assert x.grad is not None + + def test_output_shape(self): + """Test that output shape matches input shape.""" + sd = StochasticDepth(drop_prob=0.5) + sd.train() + + for shape in [(8, 64), (8, 16, 64), (8, 16, 32, 64)]: + x = torch.randn(*shape) + out = sd(x) + assert out.shape == x.shape + + +class TestMLPWithLayerScaleAndStochasticDepth: + """Integration tests for MLP with LayerScale and StochasticDepth.""" + + def test_mlp_with_layer_scale(self): + """Test MLP with LayerScale enabled.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=1e-5, + ) + + assert mlp.layer_scale is not None + assert isinstance(mlp.layer_scale, LayerScale) + + x = torch.randn(8, 16, 64) + out = mlp(x) + + assert out.shape == x.shape + + def test_mlp_with_stochastic_depth(self): + """Test MLP with StochasticDepth enabled.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + stochastic_depth_rate=0.1, + ) + + assert mlp.drop_path is not None + assert isinstance(mlp.drop_path, StochasticDepth) + + mlp.train() + x = torch.randn(8, 16, 64) + out = mlp(x) + + assert out.shape == x.shape + + def test_mlp_with_both(self): + """Test MLP with both LayerScale and StochasticDepth.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=1e-5, + stochastic_depth_rate=0.1, + ) + + assert mlp.layer_scale is not None + assert mlp.drop_path is not None + + mlp.train() + x = torch.randn(8, 16, 64) + out = mlp(x) + + assert out.shape == x.shape + + def test_mlp_without_features(self): + """Test MLP with neither feature (backward compatibility).""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + ) + + assert mlp.layer_scale is None + assert mlp.drop_path is None + + x = torch.randn(8, 16, 64) + out = mlp(x) + + assert out.shape == x.shape + + def test_mlp_layer_scale_in_state_dict(self): + """Test that LayerScale parameters appear in state_dict.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=1e-5, + ) + + state_dict = mlp.state_dict() + assert "layer_scale.gamma" in state_dict + + def test_mlp_gradient_flow_with_features(self): + """Test gradient flow through MLP with LayerScale and StochasticDepth.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=1e-5, + stochastic_depth_rate=0.1, + ) + mlp.train() + + torch.manual_seed(42) + x = torch.randn(8, 16, 64, requires_grad=True) + + out = mlp(x) + loss = out.sum() + loss.backward() + + assert x.grad is not None + 
assert mlp.layer_scale.gamma.grad is not None + + +class TestReZero: + """Tests specifically for ReZero initialization (layer_scale_init=0).""" + + def test_rezero_initial_output(self): + """Test that ReZero initially outputs just the residual.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=0.0, # ReZero + ) + + x = torch.randn(8, 16, 64) + out = mlp(x) + + # With ReZero, initial output should be approximately equal to input + # (since layer_scale starts at 0, the layer contribution is 0) + assert torch.allclose(out, x, atol=1e-5) + + def test_rezero_gradual_learning(self): + """Test that ReZero allows gradual learning of layer scale.""" + mlp = MLP( + dim_in=64, + dim_out=64, + with_residual=True, + layer_scale_init=0.0, + ) + + # Initially gamma is 0 + assert torch.allclose(mlp.layer_scale.gamma, torch.zeros(64)) + + # After gradient update, gamma should change + x = torch.randn(8, 16, 64) + target = torch.randn(8, 16, 64) + + optimizer = torch.optim.SGD(mlp.parameters(), lr=0.1) + + for _ in range(10): + optimizer.zero_grad() + out = mlp(x) + loss = ((out - target) ** 2).mean() + loss.backward() + optimizer.step() + + # Gamma should now be non-zero + assert not torch.allclose(mlp.layer_scale.gamma, torch.zeros(64)) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 000000000..fd24508a4 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,499 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
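Orientation note for the test file below: the create_optimizer factory it covers is driven by the optimizer block of the training config. A typical invocation, using the same keys and values as the optimizer_cfg / lr_cfg fixtures further down and a throwaway model (illustrative only):

import torch
import torch.nn as nn
from omegaconf import OmegaConf

from weathergen.train.optimizer import create_optimizer

optimizer_cfg = OmegaConf.create({
    "type": "muon_adamw",  # "adamw" selects plain torch.optim.AdamW instead
    "grad_clip": 1.0,
    "weight_decay": 0.1,
    "adamw": {"beta1": 0.975, "beta2": 0.9875, "eps": 2e-08},
    "muon": {"lr_multiplier": 20.0, "momentum": 0.95, "nesterov": True, "weight_decay": 0.1},
})
lr_cfg = OmegaConf.create({"lr_start": 1e-6, "lr_max": 5e-5})

model = nn.Linear(10, 5)  # placeholder: the 2-D weight goes to Muon, the 1-D bias to AdamW
optimizer = create_optimizer(model, optimizer_cfg, lr_cfg, batch_size_total=4)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()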
+ +"""Tests for the optimizer module.""" + +import pytest +import torch +import torch.nn as nn +from omegaconf import OmegaConf + +from weathergen.train.optimizer import ( + ADAMW_PATTERNS, + CompositeOptimizer, + MuonCustom, + classify_muon_params, + create_optimizer, +) + + +class DummyTransformerBlock(nn.Module): + """Simple transformer-like model for testing parameter classification.""" + + def __init__(self, dim: int = 64, num_heads: int = 4): + super().__init__() + self.dim = dim + self.num_heads = num_heads + + # Attention components (should be Muon-eligible) + self.proj_heads_q = nn.Linear(dim, dim, bias=False) + self.proj_heads_k = nn.Linear(dim, dim, bias=False) + self.proj_heads_v = nn.Linear(dim, dim, bias=False) + self.proj_out = nn.Linear(dim, dim, bias=False) + + # MLP components (should be Muon-eligible) + self.mlp_fc1 = nn.Linear(dim, dim * 4, bias=False) + self.mlp_fc2 = nn.Linear(dim * 4, dim, bias=False) + + # Embeddings (should be AdamW) + self.embed_target_coords = nn.Linear(3, dim, bias=False) + self.embeds = nn.Embedding(100, dim) + + # Prediction heads (should be AdamW) + self.pred_heads = nn.Linear(dim, 10, bias=False) + self.latent_heads = nn.Linear(dim, dim, bias=False) + + # Biases and norms (should be AdamW) + self.bias = nn.Parameter(torch.zeros(dim)) + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + return x + + +class SimpleMLP(nn.Module): + """Simple MLP for testing optimizer steps.""" + + def __init__(self, input_dim: int = 10, hidden_dim: int = 32, output_dim: int = 5): + super().__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, output_dim) + self.embed = nn.Embedding(100, hidden_dim) + + def forward(self, x): + return self.fc2(torch.relu(self.fc1(x))) + + +@pytest.fixture +def dummy_model(): + """Create a dummy transformer model for testing.""" + return DummyTransformerBlock(dim=64, num_heads=4) + + +@pytest.fixture +def simple_model(): + """Create a simple MLP model for testing optimizer steps.""" + return SimpleMLP(input_dim=10, hidden_dim=32, output_dim=5) + + +@pytest.fixture +def optimizer_cfg(): + """Create a standard optimizer config.""" + return OmegaConf.create({ + "type": "adamw", + "grad_clip": 1.0, + "weight_decay": 0.1, + "adamw": { + "beta1": 0.975, + "beta2": 0.9875, + "eps": 2e-08, + }, + "muon": { + "lr_multiplier": 20.0, + "momentum": 0.95, + "nesterov": True, + "weight_decay": 0.1, + }, + }) + + +@pytest.fixture +def lr_cfg(): + """Create a standard LR config.""" + return OmegaConf.create({ + "lr_start": 1e-6, + "lr_max": 5e-5, + }) + + +class TestClassifyMuonParams: + """Tests for the classify_muon_params function.""" + + def test_classification_separates_params(self, dummy_model): + """Test that parameters are correctly separated into Muon and AdamW groups.""" + muon_params, adamw_params, muon_names, adamw_names = classify_muon_params(dummy_model) + + # Check that all trainable params are classified + total_params = sum(1 for p in dummy_model.parameters() if p.requires_grad) + assert len(muon_params) + len(adamw_params) == total_params + + # Check names match params count + assert len(muon_params) == len(muon_names) + assert len(adamw_params) == len(adamw_names) + + def test_attention_weights_are_muon(self, dummy_model): + """Test that attention Q/K/V/O weights are classified as Muon-eligible.""" + _, _, muon_names, _ = classify_muon_params(dummy_model) + + # These should be in Muon group + expected_muon = ["proj_heads_q", "proj_heads_k", "proj_heads_v", "proj_out"] + for name in 
expected_muon: + assert any(name in muon_name for muon_name in muon_names), f"{name} should be Muon" + + def test_mlp_weights_are_muon(self, dummy_model): + """Test that MLP linear weights are classified as Muon-eligible.""" + _, _, muon_names, _ = classify_muon_params(dummy_model) + + # MLP weights should be Muon + assert any("mlp_fc1" in name for name in muon_names) + assert any("mlp_fc2" in name for name in muon_names) + + def test_embeddings_are_adamw(self, dummy_model): + """Test that embedding parameters are classified as AdamW-eligible.""" + _, _, _, adamw_names = classify_muon_params(dummy_model) + + # These should be in AdamW group + expected_adamw = ["embed_target_coords", "embeds"] + for name in expected_adamw: + assert any(name in adamw_name for adamw_name in adamw_names), f"{name} should be AdamW" + + def test_pred_heads_are_adamw(self, dummy_model): + """Test that prediction heads are classified as AdamW-eligible.""" + _, _, _, adamw_names = classify_muon_params(dummy_model) + + assert any("pred_heads" in name for name in adamw_names) + assert any("latent_heads" in name for name in adamw_names) + + def test_1d_params_are_adamw(self, dummy_model): + """Test that 1D parameters (biases, norm weights) are AdamW-eligible.""" + _, adamw_params, _, adamw_names = classify_muon_params(dummy_model) + + # Check that bias and norm params are in AdamW + assert any("bias" in name for name in adamw_names) + assert any("norm" in name for name in adamw_names) + + # All 1D params should be in AdamW + for param in adamw_params: + if param.ndim < 2: + assert True # 1D params are correctly in AdamW + + def test_frozen_params_excluded(self, dummy_model): + """Test that frozen parameters are excluded from classification.""" + # Freeze some parameters + dummy_model.proj_heads_q.weight.requires_grad = False + dummy_model.embed_target_coords.weight.requires_grad = False + + muon_params, adamw_params, muon_names, adamw_names = classify_muon_params(dummy_model) + + # Frozen params should not appear + assert "proj_heads_q.weight" not in muon_names + assert "embed_target_coords.weight" not in adamw_names + + # Total should be reduced + total_trainable = sum(1 for p in dummy_model.parameters() if p.requires_grad) + assert len(muon_params) + len(adamw_params) == total_trainable + + +class TestCreateOptimizer: + """Tests for the create_optimizer factory function.""" + + def test_creates_adamw_by_default(self, simple_model, optimizer_cfg, lr_cfg): + """Test that AdamW is created when type is 'adamw'.""" + optimizer_cfg.type = "adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + assert isinstance(optimizer, torch.optim.AdamW) + + def test_creates_composite_for_muon_adamw(self, simple_model, optimizer_cfg, lr_cfg): + """Test that CompositeOptimizer is created when type is 'muon_adamw'.""" + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + assert isinstance(optimizer, CompositeOptimizer) + + def test_raises_for_unknown_type(self, simple_model, optimizer_cfg, lr_cfg): + """Test that unknown optimizer type raises ValueError.""" + optimizer_cfg.type = "unknown" + + with pytest.raises(ValueError, match="Unknown optimizer type"): + create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + def test_batch_size_scaling(self, simple_model, optimizer_cfg, lr_cfg): + """Test that betas are scaled based on batch size.""" + optimizer_cfg.type = "adamw" + + opt_small = 
create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=1) + opt_large = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=16) + + # Larger batch should have different betas (closer to target) + beta1_small = opt_small.param_groups[0]["betas"][0] + beta1_large = opt_large.param_groups[0]["betas"][0] + + # With larger batch, beta1 should be smaller (more momentum decay) + assert beta1_large < beta1_small + + +class TestCompositeOptimizer: + """Tests for the CompositeOptimizer class.""" + + def test_step_updates_both_optimizers(self, simple_model, optimizer_cfg, lr_cfg): + """Test that step() updates parameters from both optimizers.""" + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # Create dummy input and compute loss + x = torch.randn(4, 10) + output = simple_model(x) + loss = output.sum() + + # Store initial params + initial_params = {name: p.clone() for name, p in simple_model.named_parameters()} + + # Backward and step + loss.backward() + optimizer.step() + + # Check that params changed + params_changed = False + for name, p in simple_model.named_parameters(): + if not torch.equal(p, initial_params[name]): + params_changed = True + break + + assert params_changed + + def test_zero_grad_clears_both(self, simple_model, optimizer_cfg, lr_cfg): + """Test that zero_grad() clears gradients from both optimizers.""" + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # Create gradients + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + + # Verify grads exist + has_grads = any(p.grad is not None for p in simple_model.parameters()) + assert has_grads + + # Zero grads + optimizer.zero_grad() + + # Verify grads are cleared + for p in simple_model.parameters(): + assert p.grad is None or p.grad.abs().sum() == 0 + + def test_state_dict_roundtrip(self, simple_model, optimizer_cfg, lr_cfg): + """Test that state dict can be saved and loaded.""" + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # Take a step to populate state + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + optimizer.step() + + # Save state + state_dict = optimizer.state_dict() + + # Verify structure + assert "optimizer_type" in state_dict + assert state_dict["optimizer_type"] == "composite_muon_adamw" + assert "muon" in state_dict + assert "adamw" in state_dict + assert "muon_lr_multiplier" in state_dict + + # Create new optimizer and load state + optimizer2 = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + optimizer2.load_state_dict(state_dict) + + # Take another step - should not raise + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + optimizer2.step() + + def test_param_groups_combined(self, simple_model, optimizer_cfg, lr_cfg): + """Test that param_groups contains groups from both optimizers.""" + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # Should have groups from both Muon and AdamW + assert len(optimizer.param_groups) >= 2 + + # Check that is_muon flag exists + has_muon_group = any(g.get("is_muon", False) for g in optimizer.param_groups) + has_adamw_group = any(not g.get("is_muon", True) for g in optimizer.param_groups) + + assert has_muon_group + assert has_adamw_group + + 
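The behaviour pinned down by TestCompositeOptimizer above is that of a thin wrapper delegating to one Muon and one AdamW instance. A sketch of that contract (illustrative; the real CompositeOptimizer may differ in details such as scheduler-facing defaults):

class CompositeOptimizerSketch:
    """Illustrative wrapper over a Muon and an AdamW optimizer (not the actual class)."""

    def __init__(self, muon_opt, adamw_opt, muon_lr_multiplier):
        self.muon_opt = muon_opt
        self.adamw_opt = adamw_opt
        self.muon_lr_multiplier = muon_lr_multiplier

    @property
    def param_groups(self):
        # combined view; Muon groups carry an "is_muon" marker
        return self.muon_opt.param_groups + self.adamw_opt.param_groups

    def step(self):
        self.muon_opt.step()
        self.adamw_opt.step()

    def zero_grad(self, set_to_none=True):
        self.muon_opt.zero_grad(set_to_none=set_to_none)
        self.adamw_opt.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        return {
            "optimizer_type": "composite_muon_adamw",
            "muon": self.muon_opt.state_dict(),
            "adamw": self.adamw_opt.state_dict(),
            "muon_lr_multiplier": self.muon_lr_multiplier,
        }

    def load_state_dict(self, state):
        self.muon_opt.load_state_dict(state["muon"])
        self.adamw_opt.load_state_dict(state["adamw"])

The combined param_groups view, with Muon groups flagged via "is_muon", is what lets stock LR schedulers such as OneCycleLR and LinearLR drive both optimizers at once, as checked in TestLRSchedulerCompatibility below.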
+class TestMuonCustom: + """Tests for the custom Muon optimizer implementation.""" + + def test_step_updates_params(self, simple_model): + """Test that Muon step updates parameters.""" + # Get only 2D params that will have gradients (fc1, fc2 weights) + # Exclude embedding since it's not used in the forward pass + params = [ + p for name, p in simple_model.named_parameters() + if p.ndim >= 2 and "embed" not in name + ] + optimizer = MuonCustom(params, lr=0.01, momentum=0.95) + + # Create dummy gradients + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + + # Store initial values + initial_values = [p.clone() for p in params] + + # Step + optimizer.step() + + # Check params with gradients changed + for i, p in enumerate(params): + if p.grad is not None: + assert not torch.equal(p, initial_values[i]), f"Param {i} was not updated" + + def test_momentum_buffer_created(self, simple_model): + """Test that momentum buffer is created after first step.""" + # Get params that will have gradients + params = [ + p for name, p in simple_model.named_parameters() + if p.ndim >= 2 and "embed" not in name + ] + optimizer = MuonCustom(params, lr=0.01, momentum=0.95) + + # Initially no state + assert all(len(optimizer.state[p]) == 0 for p in params) + + # Create gradients and step + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + optimizer.step() + + # Now should have momentum buffer for params with gradients + for p in params: + if p.grad is not None: + assert "momentum_buffer" in optimizer.state[p] + + def test_weight_decay_applied(self): + """Test that weight decay is applied to parameters.""" + # Simple 2D parameter + param = nn.Parameter(torch.ones(4, 4)) + optimizer = MuonCustom([param], lr=0.1, momentum=0.0, weight_decay=0.1) + + # Set gradient to zero (only weight decay should affect) + param.grad = torch.zeros_like(param) + + initial_norm = param.norm().item() + optimizer.step() + final_norm = param.norm().item() + + # Weight decay should reduce norm (since grad=0, only decay acts) + assert final_norm < initial_norm + + def test_nesterov_momentum(self): + """Test that Nesterov momentum produces different results than standard momentum.""" + torch.manual_seed(42) + + # Create two identical params + param1 = nn.Parameter(torch.randn(4, 4)) + param2 = nn.Parameter(param1.clone()) + + opt_standard = MuonCustom([param1], lr=0.1, momentum=0.9, nesterov=False) + opt_nesterov = MuonCustom([param2], lr=0.1, momentum=0.9, nesterov=True) + + # Same gradient + grad = torch.randn(4, 4) + param1.grad = grad.clone() + param2.grad = grad.clone() + + # Multiple steps + for _ in range(3): + opt_standard.step() + opt_nesterov.step() + param1.grad = grad.clone() + param2.grad = grad.clone() + + # Results should differ + assert not torch.allclose(param1, param2) + + +class TestLRSchedulerCompatibility: + """Tests for LR scheduler compatibility with CompositeOptimizer.""" + + def test_works_with_onecyclelr(self, simple_model, optimizer_cfg, lr_cfg): + """Test that CompositeOptimizer works with OneCycleLR scheduler.""" + from torch.optim.lr_scheduler import OneCycleLR + + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # This should not raise TypeError (isinstance check) or ValueError (momentum check) + # CompositeOptimizer now has proper defaults with betas and momentum + scheduler = OneCycleLR( + optimizer, + max_lr=0.01, + total_steps=100, + cycle_momentum=True, # Default - requires betas or 
momentum in defaults + ) + + # Take a few steps + for _ in range(5): + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + def test_works_with_linearlr(self, simple_model, optimizer_cfg, lr_cfg): + """Test that CompositeOptimizer works with LinearLR scheduler.""" + from torch.optim.lr_scheduler import LinearLR + + optimizer_cfg.type = "muon_adamw" + optimizer = create_optimizer(simple_model, optimizer_cfg, lr_cfg, batch_size_total=4) + + # This should not raise TypeError + scheduler = LinearLR( + optimizer, + start_factor=0.1, + total_iters=100, + ) + + # Take a few steps + for _ in range(5): + x = torch.randn(4, 10) + loss = simple_model(x).sum() + loss.backward() + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + +class TestAdamWPatterns: + """Tests for the ADAMW_PATTERNS constant.""" + + def test_patterns_match_expected_names(self): + """Test that patterns match the expected parameter name patterns.""" + expected_patterns = [ + "embed_target_coords", + "embeds.", + "embed.", + "pred_heads", + "latent_heads", + "q_cells", + "bilin", + "norm", + "bias", + ] + + for pattern in expected_patterns: + assert pattern in ADAMW_PATTERNS + + def test_class_token_in_patterns(self): + """Test that class_token and register_token are in patterns.""" + assert "class_token" in ADAMW_PATTERNS + assert "register_token" in ADAMW_PATTERNS
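For reference, TestClassifyMuonParams and TestAdamWPatterns together imply a classification rule along the following lines; this is a sketch of the expected behaviour, not a copy of classify_muon_params:

def classify_muon_params_sketch(model, adamw_patterns):
    """Split trainable parameters: >=2-D hidden weights go to Muon, everything else to AdamW."""
    muon_params, adamw_params, muon_names, adamw_names = [], [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:  # frozen parameters are excluded entirely
            continue
        if param.ndim < 2 or any(pat in name for pat in adamw_patterns):
            adamw_params.append(param)  # biases, norms, embeddings, heads, class/register tokens
            adamw_names.append(name)
        else:
            muon_params.append(param)  # attention Q/K/V/O and MLP weight matrices
            muon_names.append(name)
    return muon_params, adamw_params, muon_names, adamw_names

In short: Muon only ever sees >=2-D hidden weight matrices; embeddings, prediction/latent heads, class and register tokens, norms, and biases stay on AdamW, and frozen parameters are skipped entirely.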