Draft
Changes from all commits
52 commits
a0039ec
Log gradient norms
sophie-xhonneux Aug 6, 2025
e83903b
Prototype for recording grad norms
sophie-xhonneux Aug 6, 2025
d2995b4
Address review changes + hide behind feature flag
sophie-xhonneux Aug 7, 2025
26c6869
Final fixes including backward compatibility
sophie-xhonneux Aug 7, 2025
66da0d7
Merge branch 'develop' into sophiex/dev/log-grad-norms
sophie-xhonneux Aug 7, 2025
9a66f72
Ruff
sophie-xhonneux Aug 7, 2025
22a6fd7
More ruff stuff
sophie-xhonneux Aug 7, 2025
a1d7a27
Update to develop, prepare for new experiment series
MatKbauer Aug 12, 2025
128aeb1
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into mk/d…
MatKbauer Aug 12, 2025
d4be568
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into mk/d…
MatKbauer Aug 15, 2025
6504fc7
Rebase to latest develop
MatKbauer Sep 3, 2025
4f62e1a
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into mk/d…
MatKbauer Sep 8, 2025
87e7d3b
Merge branch 'develop' into jk/log-grad-norms/log-grad-norms
Jubeku Oct 7, 2025
754d31c
forecast config with small decoder
Jubeku Oct 8, 2025
cd7948f
Merge branch 'develop' into jk/log-grad-norms/log-grad-norms
Jubeku Oct 9, 2025
7c756a3
fixed uv.lock
Jubeku Oct 9, 2025
41716a6
test gradient logging on multi gpus
Jubeku Oct 9, 2025
b5ce171
Update branch to latest develop with configured o48 settings
Oct 10, 2025
c12e190
Setting o48 as default in era5 config
Oct 10, 2025
d95277e
Updated default config to 256 dim latent size
MatKbauer Oct 10, 2025
a734471
Update branch to latest develop
MatKbauer Oct 13, 2025
3ae99dd
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into mk/d…
MatKbauer Oct 13, 2025
eba89a6
Change epochs from 64 to 32
MatKbauer Oct 13, 2025
5615634
LayerNorm replication and analysis tools
MatKbauer Nov 10, 2025
9ccc95e
Rename fe_layer_norm_at_layers to fe_layer_norm_after_blocks
MatKbauer Nov 10, 2025
240031d
Increase epochs from 32 to 64 and resolve minor bug
MatKbauer Nov 13, 2025
75b81fe
Merged gradient logging
MatKbauer Nov 15, 2025
f65ac37
Update to develop
MatKbauer Nov 17, 2025
20ae505
Update default_config back to d2048 on the O96 grid
MatKbauer Nov 17, 2025
2731d29
Update ERA5 stream to O96 grid
MatKbauer Nov 17, 2025
912c406
Update to latest develop having mini-epoch notation
MatKbauer Nov 17, 2025
028bb98
Resolving bug after merging with develop and updating default_config
MatKbauer Nov 18, 2025
ba84066
Enable loading old model checkpoints after recent merges
MatKbauer Nov 19, 2025
4f00cc6
Update WeatherGenReader with mini-epoch notation
MatKbauer Nov 19, 2025
e44e139
Minor modifications to latent histogram plotting
MatKbauer Nov 20, 2025
c979ab4
Resolve bug in histogram plotting
MatKbauer Nov 21, 2025
d24c4b6
Replace getattr by cf.get
MatKbauer Nov 21, 2025
89670bf
Change target read-out engine from 1 to 2 layers
MatKbauer Nov 24, 2025
58474b2
Set aux-info for fe-blocks to none
MatKbauer Nov 28, 2025
184dcd9
fix a plotting bug (#1453)
SavvasMel Dec 12, 2025
d3b63d2
Update train/val dates, HL=5, fsteps=2, lat-weighting
MatKbauer Dec 12, 2025
a584e41
Defined base config for parameter search
MatKbauer Dec 19, 2025
d7e75eb
Increase encoder/decoder size and add mlflow tags
MatKbauer Dec 20, 2025
9bc45c5
Added plot_train content
MatKbauer Dec 22, 2025
40b070e
Adam betas per cli and aifs channel weighting option
MatKbauer Dec 23, 2025
8c6dafb
Updated launch script
MatKbauer Jan 6, 2026
b5c98cb
Updated multiple job launch script with pre-training dropout and nois…
MatKbauer Jan 9, 2026
76cb62c
Updated experiment launch script
MatKbauer Jan 12, 2026
d14291c
added config and scripts for launching multiple runs
ankitpatnala Jan 19, 2026
42b4318
yml and multi scripts
ankitpatnala Jan 27, 2026
892f509
changed eval_config and inference script
ankitpatnala Jan 27, 2026
a44d5a4
added spike function to the fstep weighting
ankitpatnala Jan 29, 2026
config/default_config.yml: 70 changes (51 additions, 19 deletions)
@@ -9,8 +9,8 @@ embed_dropout_rate: 0.1

target_cell_local_prediction: True

-ae_local_dim_embed: 1024
-ae_local_num_blocks: 2
+ae_local_dim_embed: 2048
+ae_local_num_blocks: 0
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True
@@ -24,7 +24,7 @@ ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
-ae_global_num_blocks: 8
+ae_global_num_blocks: 4
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
@@ -33,6 +33,7 @@ ae_global_with_qk_lnorm: True
ae_global_att_dense_rate: 1.0
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2
+ae_global_trailing_layer_norm: False

decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
pred_adapter_kv: False
@@ -42,16 +43,18 @@ pred_mlp_adaln: True

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
-forecast_offset : 0
+forecast_offset : 1
forecast_delta_hrs: 0
-forecast_steps: 0
-forecast_policy: null
+forecast_steps: 2
+forecast_policy: "fixed"
forecast_freeze_model: False
forecast_att_dense_rate: 1.0
-fe_num_blocks: 0
+fe_num_blocks: 16
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
-impute_latent_noise_std: 0.0 # 1e-4
+fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer
+impute_latent_noise_std: 1e-4

healpix_level: 5
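
The fe_layer_norm_after_blocks flag added above (commits 5615634 and 9ccc95e) places extra LayerNorms inside the forecast engine. Below is a minimal sketch of how such a flag could be consumed, assuming a plain stack of transformer blocks; ForecastEngine and its constructor arguments are hypothetical stand-ins, not WeatherGenerator's actual modules.

import torch
import torch.nn as nn

class ForecastEngine(nn.Module):
    """Hypothetical sketch: with fe_layer_norm_after_blocks: [7], one
    LayerNorm is applied after the eighth of fe_num_blocks=16 blocks
    (indices are 0-based, as the config comment notes)."""

    def __init__(self, dim: int, num_blocks: int, norm_after: list[int]):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model=dim, nhead=16, batch_first=True)
            for _ in range(num_blocks)
        )
        # One LayerNorm per configured block index.
        self.norms = nn.ModuleDict({str(i): nn.LayerNorm(dim) for i in norm_after})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, block in enumerate(self.blocks):
            x = block(x)
            if str(i) in self.norms:  # the extra LayerNorm slots in here
                x = self.norms[str(i)](x)
        return x

fe = ForecastEngine(dim=2048, num_blocks=16, norm_after=[7])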

@@ -77,7 +80,12 @@ loss_fcts_val:
  -
    - "mse"
    - 1.0

+timestep_weight: [spike_function,
+                  {"type": "probability",
+                   "values": {4: 0.6,
+                              6: 0.2,
+                              8: 0.1,
+                              10: 0.1}}]
batch_size_per_gpu: 1
batch_size_validation_per_gpu: 1
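
The timestep_weight entry above pairs a spike_function with per-step values that sum to 1.0 (commit a44d5a4 added the spike function to the fstep weighting). Whether these weights sample a rollout length or scale the per-step loss is not visible from this diff; the sketch below assumes sampling, and sample_forecast_step is a hypothetical helper.

import random

# Step -> weight pairs copied from the config above.
SPIKE_VALUES = {4: 0.6, 6: 0.2, 8: 0.1, 10: 0.1}

def sample_forecast_step(values: dict[int, float]) -> int:
    """Draw one forecast step with the configured probabilities."""
    steps, weights = zip(*values.items())
    return random.choices(steps, weights=weights, k=1)[0]

# Roughly 60% of draws land on step 4.
print(sample_forecast_step(SPIKE_VALUES))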

@@ -93,15 +101,15 @@ ema_halflife_in_thousands: 1e-3

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train with auto-encoder mode, forecast_offset should be 0
-training_mode: "masking"
+training_mode: "forecast"
# masking rate when training mode is "masking"; ignored in forecast mode
masking_rate: 0.6
# sample the masking rate (with normal distribution centered at masking_rate)
# note that a sampled masking rate leads to varying requirements
masking_rate_sampling: True
# sample a subset of all target points, useful e.g. to reduce memory requirements (also can specify per-stream)
sampling_rate_target: 1.0
-# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination"
+# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "combination"
masking_strategy: "random"
# masking_strategy_config is a dictionary of additional parameters for the masking strategy
# required for "healpix" and "channel" masking strategies
@@ -113,21 +121,23 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
                          "same_strategy_per_batch": false
                         }
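
With same_strategy_per_batch set to false and masking_rate_sampling enabled, each sample can draw its own strategy and its own masking rate. A minimal sketch of that per-sample draw follows; the width of the normal (0.1) and the clamping bounds are assumptions, since the diff does not show them.

import random

STRATEGIES = ["random", "healpix", "channel"]  # from masking_strategy_config
MASKING_RATE = 0.6

def sample_masking() -> tuple[str, float]:
    """Pick a strategy and a masking rate for one training sample."""
    strategy = random.choice(STRATEGIES)
    # Normal centered at masking_rate (std assumed), clamped to (0, 1).
    rate = min(max(random.gauss(MASKING_RATE, 0.1), 0.05), 0.95)
    return strategy, rate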

-num_mini_epochs: 32
-samples_per_mini_epoch: 4096
+num_epochs: 128
+samples_per_epoch: 4096
samples_per_validation: 512
shuffle: True

lr_scaling_policy: "sqrt"
lr_start: 1e-6
-lr_max: 5e-5
-lr_final_decay: 1e-6
+lr_max: 0.0001
+lr_final_decay: 2e-6
lr_final: 0.0
-lr_steps_warmup: 512
+lr_steps_warmup: 256
lr_steps_cooldown: 512
lr_policy_warmup: "cosine"
lr_policy_decay: "constant"
lr_policy_cooldown: "linear"
+adam_beta1: null # Becomes 0.8 with 2 nodes
+adam_beta2: null # Becomes 0.9 with 2 nodes

grad_clip: 1.0
weight_decay: 0.1
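
The lr_* fields above describe a three-phase schedule: cosine warmup from lr_start to lr_max, a constant phase, and a linear cooldown. The sketch below is one plausible composition of those fields, assuming the constant phase holds lr_max and the cooldown runs from lr_final_decay down to lr_final; the actual schedule lives in WeatherGenerator's trainer and may differ.

import math

def lr_at_step(step: int, total_steps: int) -> float:
    """Hypothetical schedule assembled from the config fields above."""
    lr_start, lr_max = 1e-6, 1e-4
    lr_final_decay, lr_final = 2e-6, 0.0
    warmup, cooldown = 256, 512

    if step < warmup:  # lr_policy_warmup: "cosine"
        t = step / warmup
        return lr_start + (lr_max - lr_start) * 0.5 * (1.0 - math.cos(math.pi * t))
    if step >= total_steps - cooldown:  # lr_policy_cooldown: "linear"
        t = (step - (total_steps - cooldown)) / cooldown
        return lr_final_decay + (lr_final - lr_final_decay) * t
    return lr_max  # lr_policy_decay: "constant"
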
@@ -136,9 +146,9 @@ nn_module: "te"
log_grad_norms: False

start_date: 197901010000
-end_date: 202012310000
-start_date_val: 202101010000
-end_date_val: 202201010000
+end_date: 202212310000
+start_date_val: 202310010000
+end_date_val: 202312310000
len_hrs: 6
step_hrs: 6
input_window_steps: 1
@@ -161,3 +171,25 @@ train_log_freq:
  terminal: 10
  metrics: 20
  checkpoint: 250

+# Tags for experiment tracking
+# These tags will be logged in MLflow along with completed runs for train, eval, val
+# The tags are free-form, with the following rules:
+# - tags should be primitive types (strings, numbers, booleans); NO lists or dictionaries
+# - tags should not duplicate existing config entries
+# - try to reuse existing tags where possible; MLflow does not like having too many unique tags
+# - do not use long strings in values (less than 20 characters is a good rule of thumb; we may enforce this in the future)
+wgtags:
+  # The name of the organization of the person running the experiment.
+  # This may be autofilled in the future. Expected values are lowercase strings of
+  # the organizations' codenames in https://confluence.ecmwf.int/display/MAEL/Staff+Contact+List
+  # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience"
+  org: None
+  issue: 1495
+  # The name of the experiment. This is a distinctive codename for the experiment campaign being run.
+  # This is expected to be the primary tag for comparing experiments in MLflow.
+  # Expected values are lowercase strings with no spaces, just underscores,
+  # e.g. "rollout_ablation_grid"
+  exp: "rollout_params"
+  # *** Experiment-specific tags ***
+  grid_search: "dropout"
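
Below is a sketch of how a wgtags block like the one above could be forwarded to MLflow with the standard client API; the run name and the resolved org value ("ecmwf") are illustrative placeholders, not values taken from this PR.

import mlflow

# Values from the wgtags block above; "ecmwf" stands in for the unset org.
wgtags = {"org": "ecmwf", "issue": 1495, "exp": "rollout_params",
          "grid_search": "dropout"}

with mlflow.start_run(run_name="rollout_params"):
    # MLflow stores tag values as strings, so keep them primitive and
    # short, as the rules in the comment block require.
    mlflow.set_tags({key: str(value) for key, value in wgtags.items()})
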
config/eval_config.yml: 219 changes (219 additions, 0 deletions)
@@ -0,0 +1,219 @@
global_plotting_options:
  image_format: "png"  # options: "png", "pdf", "svg", "eps", "jpg", ...
  dpi_val: 300
  ERA5:
    marker_size: 4

evaluation:
  metrics: ["froct", "rmse"]
  regions: ["global"]
  summary_plots: true
  summary_dir: "./plots/"
  print_summary: false  # print score values on screen; can be verbose
  log_scale: false
  add_grid: true

run_ids:

  # lr=5e-4
  #xs5l8zmj:
  #  label: "cosine scheduler lr_max=5e-4 v1"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  #x9zvml1k:
  #  label: "cosine scheduler lr_max=5e-4 v2"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  #an8rap5h:
  #  label: "cosine scheduler lr_max=5e-4 v3"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  # lr=1e-4
  #u2qk39pi:
  #  label: "cosine scheduler lr_max=1e-4 v1 epoch=32"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  #zswipf53:
  #  label: "cosine scheduler lr_max=1e-4 v2 epoch=32"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  #dsdvzg59:
  #  label: "cosine scheduler lr_max=1e-4 v3 epoch=32"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  qc5dw7ki:
    label: "cosine scheduler lr_max=1e-4 v1 epoch=48"
    epoch: 0
    rank: 0
    streams:
      ERA5:
        channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
        evaluation:
          sample: "all"
          forecast_step: "all"

  oqe79vpk:
    label: "cosine scheduler lr_max=1e-4 v2 epoch=48"
    epoch: 0
    rank: 0
    streams:
      ERA5:
        channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
        evaluation:
          sample: "all"
          forecast_step: "all"

  hhblaokc:
    label: "cosine scheduler lr_max=1e-4 v3 epoch=48"
    epoch: 0
    rank: 0
    streams:
      ERA5:
        channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
        evaluation:
          sample: "all"
          forecast_step: "all"

  ## lr=5e-5
  #r812ji96:
  #  label: "cosine scheduler lr_max=5e-5 v1"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  #gj6eq2dx:
  #  label: "cosine scheduler lr_max=5e-5 v2"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  #ff80snum:
  #  label: "cosine scheduler lr_max=5e-5 v3"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  ## lr=1e-5
  #v0yha29i:
  #  label: "cosine scheduler lr_max=1e-5 v1"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  #cbmk73y0:
  #  label: "cosine scheduler lr_max=1e-5 v2"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  #ngdrjcbt:
  #  label: "cosine scheduler lr_max=1e-5 v3"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  ## lr=5e-6
  #voulcvsi:
  #  label: "cosine scheduler lr_max=5e-6 v1"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  #urlp39xq:
  #  label: "cosine scheduler lr_max=5e-6 v2"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

  #ch1n05gd:
  #  label: "cosine scheduler lr_max=5e-6 v3"
  #  epoch: 0
  #  rank: 0
  #  streams:
  #    ERA5:
  #      channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"]
  #      evaluation:
  #        sample: "all"
  #        forecast_step: "all"

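A sketch of walking this file's run_ids section with PyYAML; the field access mirrors the structure above, while the evaluation driver that actually consumes this config is not part of the diff.

import yaml

with open("config/eval_config.yml") as f:
    cfg = yaml.safe_load(f)

# Only the uncommented runs (qc5dw7ki, oqe79vpk, hhblaokc) survive parsing.
for run_id, run in (cfg.get("run_ids") or {}).items():
    era5 = run["streams"]["ERA5"]
    print(run_id, run["label"], era5["channels"])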