Changes from all commits (41 commits):
505efde  Push to test (Feb 2, 2026)
a0e5b38  Fix merge issue (sophie-xhonneux, Feb 2, 2026)
7e95bef  Claude fixing things (Feb 2, 2026)
2e1bd76  Fixing Betas expected everywhere (Feb 2, 2026)
84fa7d1  First commit (Feb 3, 2026)
c98c746  use existing implementation (Feb 3, 2026)
c2495b5  Add Layerscale etc to default config (sophie-xhonneux, Feb 3, 2026)
68aa2f4  Make JEPA default config for testing (sophie-xhonneux, Feb 3, 2026)
c0ce9dd  Add assert to prevent silent errors (Feb 4, 2026)
aebe434  Merge branch 'develop' into sophiex/dev/test-layerscale-etc (Feb 4, 2026)
71d2cce  Add collapse monitoring (Feb 4, 2026)
1d29611  Fix bug (Feb 4, 2026)
bc92ae7  Fix SVD computation failing (Feb 4, 2026)
7693c19  Reduce variables logged (Feb 4, 2026)
7f8de00  Fix EMA beta value computation (Feb 4, 2026)
c3eb019  Refactor get_current_beta to ema.py (Feb 4, 2026)
59a0a89  Sensible default for ema in jepa (sophie-xhonneux, Feb 4, 2026)
505331c  Merge branch 'sophiex/dev/monitor-collapse' into sophiex/dev/test-lay… (sophie-xhonneux, Feb 4, 2026)
4a091c8  New defaults (sophie-xhonneux, Feb 5, 2026)
32d951b  Implement Frozenteacher (Feb 6, 2026)
3298252  Test config (sophie-xhonneux, Feb 6, 2026)
b4c46b1  Refactor frozen teacher creation (Feb 6, 2026)
590d366  Fix stuff (Feb 6, 2026)
64ae9f1  Fix (Feb 6, 2026)
4444b04  Debug more (Feb 6, 2026)
c3e52d0  Enable frozen models not trained with SSL (Feb 6, 2026)
211f477  Improve code quality (Feb 6, 2026)
491a69d  Test config (sophie-xhonneux, Feb 6, 2026)
08dbf6f  Update jepa config (sophie-xhonneux, Feb 6, 2026)
1133018  Try SALT training (sophie-xhonneux, Feb 7, 2026)
11b2e60  Fix model path loading (Feb 16, 2026)
e08b7f8  Fix model_path (sophie-xhonneux, Feb 16, 2026)
b42a778  Fix inference corner case (#1818) (clessig, Feb 6, 2026)
8521b8c  fix latent_loss check in mode handling (#1784) (TillHae, Feb 6, 2026)
57c8518  Streamline run_train.py so it is suitable to be run both as a script … (grassesi, Feb 6, 2026)
424e188  Sgrasse/develop/435 unify dataset access (#1757) (grassesi, Feb 6, 2026)
141315b  Fix plot_train (#1831) (clessig, Feb 13, 2026)
b5b4a44  nse_metric (#1833) (jesicapinon, Feb 16, 2026)
f0d4a06  Random encoder fientuning on NPPATMS (sophie-xhonneux, Feb 17, 2026)
66394b5  Update configs to fix leak (sophie-xhonneux, Feb 17, 2026)
0884ae6  Plotting config (sophie-xhonneux, Feb 18, 2026)
70 changes: 54 additions & 16 deletions config/config_jepa.yml
@@ -26,7 +26,7 @@ ae_adapter_with_residual: True
 ae_adapter_dropout_rate: 0.1

 ae_global_dim_embed: 2048
-ae_global_num_blocks: 2
+ae_global_num_blocks: 0
 ae_global_num_heads: 32
 ae_global_dropout_rate: 0.1
 ae_global_with_qk_lnorm: True
@@ -37,7 +37,7 @@ ae_global_block_factor: 64
 ae_global_mlp_hidden_factor: 2
 ae_global_trailing_layer_norm: False

-ae_aggregation_num_blocks: 8
+ae_aggregation_num_blocks: 12
 ae_aggregation_num_heads: 32
 ae_aggregation_dropout_rate: 0.1
 ae_aggregation_with_qk_lnorm: True
@@ -130,10 +130,33 @@ data_loading :

 # config for training
 training_config:

   # training_mode: "masking", "student_teacher", "latent_loss"
   training_mode: ["student_teacher"]

+  # Collapse monitoring for SSL training (JEPA/DINO/iBOT)
+  # Detects representation collapse via various metrics
+  collapse_monitoring:
+    enabled: true
+    compute_frequency: 100 # batches between metric computations
+    log_frequency: 100 # batches between metric logging
+    metrics:
+      effective_rank:
+        enabled: true
+        tensor_source: "both" # "student", "teacher", or "both"
+        sample_size: 2048 # max samples for SVD (0 = no sampling)
+      singular_values:
+        enabled: true
+        tensor_source: "both"
+        sample_size: 2048
+      dimension_variance:
+        enabled: true
+        tensor_source: "both" # cheap to compute, good early indicator
+      prototype_entropy:
+        enabled: true # only applies to DINO
+      ema_beta:
+        enabled: true
+
   num_mini_epochs: 32
   samples_per_mini_epoch: 4096
   shuffle: True
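For context on the two cheapest-to-read metrics above: effective rank is the exponential of the entropy of the normalized singular-value spectrum of a batch of embeddings, and dimension variance flags embedding dimensions whose variance has collapsed toward zero. A minimal PyTorch sketch, with hypothetical helper names (the PR's actual monitoring code may differ):

```python
import torch

def effective_rank(z: torch.Tensor, sample_size: int = 2048) -> torch.Tensor:
    # z: (num_samples, dim) latent embeddings from student or teacher.
    # Subsample rows to bound the SVD cost, as sample_size does in the config.
    if sample_size and z.shape[0] > sample_size:
        z = z[torch.randperm(z.shape[0])[:sample_size]]
    z = z - z.mean(dim=0, keepdim=True)        # center each dimension
    # SVD in float32: half-precision inputs are a common cause of
    # torch.linalg.svdvals failures.
    s = torch.linalg.svdvals(z.float())
    p = s / s.sum()                            # normalized singular-value spectrum
    p = p[p > 0]
    return torch.exp(-(p * p.log()).sum())     # exp(spectral entropy)

def dimension_variance(z: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Fraction of embedding dimensions with effectively zero variance;
    # cheap to compute and rises early when representations collapse.
    var = z.float().var(dim=0)
    return (var < eps).float().mean()
```

A fully collapsed representation drives effective rank toward 1 and the zero-variance fraction toward 1, while a healthy JEPA run keeps effective rank at a sizable fraction of the embedding dimension.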
@@ -148,25 +171,36 @@ training_config:

   learning_rate_scheduling :
     lr_start: 1e-6
-    lr_max: 5e-5
+    lr_max: 1e-4
     lr_final_decay: 1e-6
     lr_final: 0.0
-    num_steps_warmup: 512
+    num_steps_warmup: 4096
     num_steps_cooldown: 512
     policy_warmup: "cosine"
     policy_decay: "constant"
     policy_cooldown: "linear"
     parallel_scaling_policy: "sqrt"

   optimizer:
-    grad_clip: 1.0
-    weight_decay: 0.1
+    # Optimizer type: "adamw" (default) or "muon_adamw" (Muon for hidden weights, AdamW for embeddings/heads)
+    type: "muon_adamw"
+    grad_clip: 0.1
+    weight_decay: 0.05
     log_grad_norms: False
     adamw :
       # parameters are scaled by number of DDP workers
       beta1 : 0.975
       beta2 : 0.9875
       eps : 2e-08
+    muon:
+      # Learning rate multiplier for Muon relative to base LR (muon_lr = base_lr * lr_multiplier)
+      lr_multiplier: 30.0
+      # Momentum factor for Muon SGD
+      momentum: 0.95
+      # Use Nesterov momentum
+      nesterov: true
+      # Weight decay for Muon parameters (uses optimizer.weight_decay if not specified)
+      weight_decay: 0.05

   losses : {
     "student-teacher": {
@@ -179,16 +213,20 @@ training_config:
       "num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768,
       "dropout_rate": 0.1,
       target_source_correspondence: {0 : {0 : "subset"} },
      },
    },
  },
-  target_and_aux_calc: { "EMATeacher" :
-    { ema_ramp_up_ratio : 0.09,
-      ema_halflife_in_thousands: 1e-3,
-      model_param_overrides : {
-        training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}}
-      },
-    }
-  }
+  target_and_aux_calc: {FrozenTeacher: {
+    teacher_run_id: "yoqxf234", # "zosrc8ti", # Required
+    teacher_mini_epoch: -1}},
+  # },
+  # target_and_aux_calc: { "EMATeacher" :
+  #   { ema_ramp_up_ratio : null,
+  #     ema_halflife_in_thousands: 1e-1,
+  #     model_param_overrides : {
+  #       training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}}
+  #     },
+  #   }
+  # }
   }
 }
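The commented-out EMATeacher block, together with the commits "Fix EMA beta value computation" and "Refactor get_current_beta to ema.py", indicates the teacher's EMA decay is derived from a halflife. A minimal sketch consistent with the config keys above; the ramp-up rule is an assumption modeled on common EMA schedules, and the real get_current_beta in ema.py may differ:

```python
def get_current_beta(step: int,
                     ema_halflife_in_thousands: float,
                     ema_ramp_up_ratio: float | None = None) -> float:
    # beta = 0.5 ** (1 / halflife_steps): after halflife_steps student
    # updates, an old teacher weight's contribution has decayed by half.
    halflife_steps = ema_halflife_in_thousands * 1000.0
    if ema_ramp_up_ratio is not None:
        # Assumed ramp-up: cap the halflife early in training so the teacher
        # initially tracks the student closely, then slows down.
        halflife_steps = min(halflife_steps, step * ema_ramp_up_ratio)
    return 0.5 ** (1.0 / max(halflife_steps, 1e-8))
```

The switch to FrozenTeacher removes this schedule entirely: the teacher is a fixed checkpoint identified by teacher_run_id (teacher_mini_epoch: -1 presumably selecting the last saved mini-epoch), so no EMA update runs at all.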
9 changes: 5 additions & 4 deletions config/config_jepa_finetuning.yml
@@ -92,7 +92,8 @@ zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore
 #####################################

 # streams_directory: "./config/streams/era5_1deg/"
-streams_directory: "./config/streams/era5_synop_finetuning/"
+# streams_directory: "./config/streams/era5_synop_finetuning/"
+streams_directory: "./config/streams/era5_nppatms_finetuning/"
 streams: ???

 general:
@@ -139,8 +140,8 @@ training_config:
   samples_per_mini_epoch: 4096
   shuffle: True

-  start_date: 1979-01-01T00:00
-  end_date: 2022-12-31T00:00
+  start_date: 2012-01-01T00:00
+  end_date: 2021-12-31T00:00

   time_window_step: 06:00:00
   time_window_len: 06:00:00
@@ -271,7 +272,7 @@ validation_config:
     # write samples in normalized model space
     normalized_samples: False,
     # output streams to write; default all
-    streams: ["SurfaceCombined"],
+    streams: ["NPPATMS"],
   }

   # run validation before training starts (mainly for model development)