diff --git a/config/default_config.yml b/config/default_config.yml index 19c8769fb..f86c1a591 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -9,8 +9,8 @@ embed_dropout_rate: 0.1 target_cell_local_prediction: True -ae_local_dim_embed: 1024 -ae_local_num_blocks: 2 +ae_local_dim_embed: 2048 +ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 ae_local_with_qk_lnorm: True @@ -24,7 +24,7 @@ ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 ae_global_dim_embed: 2048 -ae_global_num_blocks: 8 +ae_global_num_blocks: 4 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -33,6 +33,7 @@ ae_global_with_qk_lnorm: True ae_global_att_dense_rate: 1.0 ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning pred_adapter_kv: False @@ -42,16 +43,18 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 0 +forecast_offset : 1 forecast_delta_hrs: 0 -forecast_steps: 0 -forecast_policy: null +forecast_steps: 2 +forecast_policy: "fixed" +forecast_freeze_model: False forecast_att_dense_rate: 1.0 -fe_num_blocks: 0 +fe_num_blocks: 16 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True -impute_latent_noise_std: 0.0 # 1e-4 +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +impute_latent_noise_std: 1e-4 healpix_level: 5 @@ -77,7 +80,12 @@ loss_fcts_val: - - "mse" - 1.0 - +timestep_weight: [spike_function, + {"type":"probability", + "values":{ 4 : 0.6, + 6 : 0.2, + 8 : 0.1, + 10 : 0.1} ] batch_size_per_gpu: 1 batch_size_validation_per_gpu: 1 @@ -93,7 +101,7 @@ ema_halflife_in_thousands: 1e-3 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" +training_mode: "forecast" # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # sample the masking rate (with normal distribution centered at masking_rate) @@ -101,7 +109,7 @@ masking_rate: 0.6 masking_rate_sampling: True # sample a subset of all target points, useful e.g. to reduce memory requirements (also can specify per-stream) sampling_rate_target: 1.0 -# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination" +# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "combination" masking_strategy: "random" # masking_strategy_config is a dictionary of additional parameters for the masking strategy # required for "healpix" and "channel" masking strategies @@ -113,21 +121,23 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false } -num_mini_epochs: 32 -samples_per_mini_epoch: 4096 +num_epochs: 128 +samples_per_epoch: 4096 samples_per_validation: 512 shuffle: True lr_scaling_policy: "sqrt" lr_start: 1e-6 -lr_max: 5e-5 -lr_final_decay: 1e-6 +lr_max: 0.0001 +lr_final_decay: 2e-6 lr_final: 0.0 -lr_steps_warmup: 512 +lr_steps_warmup: 256 lr_steps_cooldown: 512 lr_policy_warmup: "cosine" lr_policy_decay: "constant" lr_policy_cooldown: "linear" +adam_beta1: null # Becomes 0.8 with 2 nodes +adam_beta2: null # Becomes 0.9 with 2 nodes grad_clip: 1.0 weight_decay: 0.1 @@ -136,9 +146,9 @@ nn_module: "te" log_grad_norms: False start_date: 197901010000 -end_date: 202012310000 -start_date_val: 202101010000 -end_date_val: 202201010000 +end_date: 202212310000 +start_date_val: 202310010000 +end_date_val: 202312310000 len_hrs: 6 step_hrs: 6 input_window_steps: 1 @@ -161,3 +171,25 @@ train_log_freq: terminal: 10 metrics: 20 checkpoint: 250 + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings of + # the organizations codenames in https://confluence.ecmwf.int/display/MAEL/Staff+Contact+List + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: None + issue: 1495 + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: "rollout_params" + # *** Experiment-specific tags *** + grid_search: "dropout" diff --git a/config/eval_config.yml b/config/eval_config.yml new file mode 100644 index 000000000..a8135c324 --- /dev/null +++ b/config/eval_config.yml @@ -0,0 +1,219 @@ +global_plotting_options: + image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. + dpi_val : 300 + ERA5: + marker_size: 4 + +evaluation: + metrics : ["froct", "rmse"] + regions: ["global"] + summary_plots : true + summary_dir: "./plots/" + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: true + +run_ids : + + # lr=5e-4 + #xs5l8zmj: + # label: "cosine scheduler lr_max=5e-4 v1" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #x9zvml1k: + # label: "cosine scheduler lr_max=5e-4 v2" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #an8rap5h: + # label: "cosine scheduler lr_max=5e-4 v3" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + # lr=1e-4 + #u2qk39pi: + # label: "cosine scheduler lr_max=1e-4 v1 epoch=32" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #zswipf53: + # label: "cosine scheduler lr_max=1e-4 v2 epoch=32" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #dsdvzg59: + # label: "cosine scheduler lr_max=1e-4 v3 epoch=32" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + qc5dw7ki: + label: "cosine scheduler lr_max=1e-4 v1 epoch=48" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + oqe79vpk: + label: "cosine scheduler lr_max=1e-4 v2 epoch=48" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + hhblaokc: + label: "cosine scheduler lr_max=1e-4 v3 epoch=48" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + ## lr=5e-5 + #r812ji96: + # label: "cosine scheduler lr_max=5e-5 v1" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #gj6eq2dx: + # label: "cosine scheduler lr_max=5e-5 v2" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #ff80snum: + # label: "cosine scheduler lr_max=5e-5 v3" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + # ## lr=1e-5 + #v0yha29i: + # label: "cosine scheduler lr_max=1e-5 v1" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #cbmk73y0: + # label: "cosine scheduler lr_max=1e-5 v2" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #ngdrjcbt: + # label: "cosine scheduler lr_max=1e-5 v3" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + ## lr=5e-6 + #voulcvsi: + # label: "cosine scheduler lr_max=5e-6 v1" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #urlp39xq: + # label: "cosine scheduler lr_max=5e-6 v2" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #ch1n05gd: + # label: "cosine scheduler lr_max=5e-6 v3" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + diff --git a/config/eval_config_matthias.yml b/config/eval_config_matthias.yml new file mode 100644 index 000000000..cbd200d66 --- /dev/null +++ b/config/eval_config_matthias.yml @@ -0,0 +1,115 @@ +global_plotting_options: + image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. + dpi_val : 300 + ERA5: + marker_size: 5 + +evaluation: + metrics : ["froct", "rmse"] + regions: ["global"] + summary_plots : true + summary_dir: "./plots/" + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: true + +run_ids : + + otnh7lcg: + label: "Control: fine-tune 1979-2022, 8x32 epochs, v1" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + v8bc9sxg: + label: "fine-tune 2018-2022, 8x32 epochs, v1" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + adja05fc: + label: "fine-tune 2018-2022, 8x32 epochs, v2" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + nub26n5i: + label: "fine-tune 2018-2022, seq [4, 6, 8] epochs, v1" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + kl5g92ne: + label: "fine-tune 2018-2022, seq [4, 6, 8] epochs, v2" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + bwzo5qfn: + label: "fine-tune 2018-2022, seq [4, 6, 8] epochs, v3" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + jrp2kgem: + label: "fine-tune 2018-2022, seq [3, 4, ..., 8] epochs, v1" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + ibmenr7o: + label: "fine-tune 2018-2022, seq [3, 4, ..., 8] epochs, v2" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + renlmg1i: + label: "fine-tune 2018-2022, seq [3, 4, ..., 8] epochs, v3" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" diff --git a/config/eval_config_seq.yml b/config/eval_config_seq.yml new file mode 100644 index 000000000..232dd0877 --- /dev/null +++ b/config/eval_config_seq.yml @@ -0,0 +1,35 @@ +global_plotting_options: + image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. + dpi_val : 300 + +evaluation: + metrics : ["froct", "rmse","acc"] #, "mae"] + regions: ["global"] + summary_plots : true + summary_dir: "./plots/" + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: false + score_cards: true + num_processes: auto + + +run_ids : + m9p1w7kx: + label: "8 steps all" + epoch: 0 + rank: 0 + streams: + ERA5: + climatology_path: /iopsstor/scratch/cscs/lessig/data/assets/climatology/aifs-ea-an-oper-0001-mars-o96-1980-2020-6h-v6_climatology.zarr + channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + #channels: ["2t", "q_850", ] + evaluation: + sample: "all" + forecast_step: "all" + plotting: + sample: [0] + forecast_step: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80 ] + plot_maps: false + plot_histograms: false + plot_animations: false diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml new file mode 100644 index 000000000..2d18fa589 --- /dev/null +++ b/config/runs_plot_train.yml @@ -0,0 +1,240 @@ +train : + plot : + # in9eslqf : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.05, v1" + # eval: vgbndhco + # t5vqafju : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.05, v2" + # eval: vgbndhco + # qz9n6815 : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.05, v3" + # eval: vgbndhco + + # scapqu18 : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.1, v1" + # eval: vgbndhco + # yqy2ezoa : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.1, v2" + # eval: vgbndhco + # bp9lcgwn : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.1, v3" + # eval: vgbndhco + + # vzeakjlb : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.2, v1" + # eval: vgbndhco + # a8n4zrfs : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.2, v2" + # eval: vgbndhco + # cxicf671 : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.2, v3" + # eval: vgbndhco + + # z3infogw : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.4, v1" + # eval: vgbndhco + # a9hp2qju : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.4, v2" + # eval: vgbndhco + # achvju39 : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.4, v3" + # eval: vgbndhco + + #h8l1yem5 : + # slurm_id: 0 + # description: "lr 5e-4, v2" + # eval: vgbndhco + #jeubm9ld : + # slurm_id: 0 + # description: "lr 5e-4, v3" + # eval: vgbndhco + #bezt6v8g : + # slurm_id: 0 + # description: "lr 5e-4, v1" + # eval: vgbndhco + # + #ya0gty48 : + # slurm_id: 0 + # description: "lr 1e-4, v1" + # eval: vgbndhco + #djiy1v3e : + # slurm_id: 0 + # description: "lr 1e-4, v2" + # eval: vgbndhco + #g96y1dq5 : + # slurm_id: 0 + # description: "lr 1e-4, v3" + # eval: vgbndhco + + + # mia69x1h : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.05, v1" + # eval: vgbndhco + # lgzkdwls : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.05, v2" + # eval: vgbndhco + # jr39znm6 : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.05, v3" + # eval: vgbndhco + + # gzxgp7cw : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.1, v1" + # eval: vgbndhco + # el6zytfd : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.1, v2" + # eval: vgbndhco + # c64w3cgy : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.1, v3" + # eval: vgbndhco + + # m4x3a0jt : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.2, v1" + # eval: vgbndhco + # manyrowd : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.2, v2" + # eval: vgbndhco + # ijwbpy3k : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.2, v3" + # eval: vgbndhco + + # i9qkv084 : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.4, v1" + # eval: vgbndhco + # l78tqy2z : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.4, v2" + # eval: vgbndhco + # xn4wa7b2 : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.4, v3" + # eval: vgbndhco + + #s9sldzyb : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.6, v1" + # eval: vgbndhco + #e29izt1j : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.6, v2" + # eval: vgbndhco + #bmoc645w : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.6, v3" + # eval: vgbndhco + + + + # q4l8jb2e : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.05, v1" + # eval: vgbndhco + # eytr9nki : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.05, v2" + # eval: vgbndhco + # bbcm27x1 : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.05, v3" + # eval: vgbndhco + + # jjbfpuya : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.1, v1" + # eval: vgbndhco + # wxehoqic : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.1, v2" + # eval: vgbndhco + # sbylixor : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.1, v3" + # eval: vgbndhco + + # scguorkl : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.2, v1" + # eval: vgbndhco + # uh0iz8sa : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.2, v2" + # eval: vgbndhco + # jizcxg9f : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.2, v3" + # eval: vgbndhco + + # d2wgjec9 : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.4, v1" + # eval: vgbndhco + # g10zvcn4 : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.4, v2" + # eval: vgbndhco + # a2vlj964 : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.4, v3" + # eval: vgbndhco + + #lvy8406i : + # slurm_id: 0 + # description: "lr 1e-4, cool_down_steps=512, v1" + # eval: vgbndhco + #ipn3jryk : + # slurm_id: 0 + # description: "lr 1e-4, cool_down_steps=512, v2" + # eval: vgbndhco + #wrucxsk6 : + # slurm_id: 0 + # description: "lr 1e-4, cool_down_steps=2048, v1" + # eval: vgbndhco + #dfvo0ir1 : + # slurm_id: 0 + # description: "lr 1e-4, cool_down_steps=2048, v2" + # eval: vgbndhco + czfrhdae : + slurm_id: 0 + description: "lr 1e-4, cosine_lr=32 epochs, v1" + eval: vgbndhco + zha9i6x3 : + slurm_id: 0 + description: "lr 1e-4, cosine_lr=32 epochs, v1" + eval: vgbndhco + ypr1b3a4 : + slurm_id: 0 + description: "lr 1e-4, cosine_lr=32 epochs, v1" + eval: vgbndhco + otn1u3oe : + slurm_id: 0 + description: "lr 1e-4, cosine_lr=48 epochs, v1" + eval: vgbndhco + r2z01faj : + slurm_id: 0 + description: "lr 1e-4, cosine_lr=48 epochs, v2" + eval: vgbndhco + xxjfcwq1 : + slurm_id: 0 + description: "lr 1e-4, cosine_lr=48 epochs,v2" + eval: vgbndhco diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index bb2234c4e..33b47d9bd 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -10,9 +10,11 @@ ERA5 : type : anemoi filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + # filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr'] source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] loss_weight : 1. + location_weight : cosine_latitude masking_rate : 0.6 masking_rate_none : 0.05 token_size : 8 @@ -22,11 +24,11 @@ ERA5 : net : transformer num_tokens : 1 num_heads : 8 - dim_embed : 256 + dim_embed : 512 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 256 + dim_embed : 512 target_readout : type : 'obs_value' # token or obs_value num_layers : 2 diff --git a/config/streams/era5_1deg_w-aifs/era5.yml b/config/streams/era5_1deg_w-aifs/era5.yml new file mode 100644 index 000000000..69108aa11 --- /dev/null +++ b/config/streams/era5_1deg_w-aifs/era5.yml @@ -0,0 +1,105 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + # filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr'] + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + loss_weight : 1. + location_weight : cosine_latitude + channel_weights : + q_50: 0.2 + q_100: 0.23 + q_150: 0.26 + q_200: 0.29 + q_250: 0.33 + q_300: 0.36 + q_400: 0.42 + q_500: 0.48 + q_600: 0.55 + q_700: 0.61 + q_850: 0.71 + q_925: 0.75 + q_1000: 0.8 + t_50: 0.2 + t_100: 0.23 + t_150: 0.26 + t_200: 0.29 + t_250: 0.33 + t_300: 0.36 + t_400: 0.42 + t_500: 0.48 + t_600: 0.55 + t_700: 0.61 + t_850: 0.71 + t_925: 0.75 + t_1000: 0.8 + u_50: 0.2 + u_100: 0.23 + u_150: 0.26 + u_200: 0.29 + u_250: 0.33 + u_300: 0.36 + u_400: 0.42 + u_500: 0.48 + u_600: 0.55 + u_700: 0.61 + u_850: 0.71 + u_925: 0.75 + u_1000: 0.8 + v_50: 0.2 + v_100: 0.23 + v_150: 0.26 + v_200: 0.29 + v_250: 0.33 + v_300: 0.36 + v_400: 0.42 + v_500: 0.48 + v_600: 0.55 + v_700: 0.61 + v_850: 0.71 + v_925: 0.75 + v_1000: 0.8 + z_50: 0.2 + z_100: 0.23 + z_150: 0.26 + z_200: 0.29 + z_250: 0.33 + z_300: 0.36 + z_400: 0.42 + z_500: 0.48 + z_600: 0.55 + z_700: 0.61 + z_850: 0.71 + z_925: 0.75 + z_1000: 0.8 + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + type : 'obs_value' # token or obs_value + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/launch_multi.sh b/launch_multi.sh new file mode 100644 index 000000000..092e1d4f6 --- /dev/null +++ b/launch_multi.sh @@ -0,0 +1,101 @@ + +# GRID SEARCH 1 +# for lr in "5e-4" "1e-4" "5e-5" ; do +# for w_dec in 0.05 0.1 0.2 0.4 0.6 ; do +# echo "$lr $w_dec" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --chain-jobs 3 --link-venv --options lr_max=$lr weight_decay=$w_dec +# done +# done + +# # Resume training +# #for run_id in in9eslqf t5vqafju qz9n6815 scapqu18 yqy2ezoa bp9lcgwn vzeakjlb a8n4zrfs cxicf671 z3infogw a9hp2qju achvju39 cszpe803 y0kauh4s eneq4ahr ; do +# #for run_id in mia69x1h lgzkdwls jr39znm6 gzxgp7cw el6zytfd c64w3cgy m4x3a0jt manyrowd ijwbpy3k i9qkv084 l78tqy2z xn4wa7b2 s9sldzyb e29izt1j bmoc645w ; do +# for run_id in q4l8jb2e eytr9nki bbcm27x1 jjbfpuya wxehoqic sbylixor scguorkl uh0iz8sa jizcxg9f d2wgjec9 g10zvcn4 a2vlj964 z8vx03bg e3k2v450 qmil5gwk ; do +# echo "$run_id" +# #cp ../WeatherGenerator-private/hpc/santis/weathergen_slurm_train.sh /capstor/scratch/cscs/mkarlbau/slurm/slurm_weathergen_"$run_id"_dir/WeatherGenerator-private/hpc/santis/. +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv +# done + +# # Delete validation.zarr and plots directories +# # ww0r248v, pg5oe6rq, jc10pgys, usizx285, yqy2ezoa, xbgazwtc, bxf4bdlz, tmnrdyk9, ue5ky698, hd4mvfa1, s301cqla, m63ocvtq, fq2jposb, fh1drqz8, r2ut4clp, u16fvide, t790hb8a, d0bou6tk, ge9zmby0, c745lzyr, sevmrclb, rqer37pc, ot3hqr0x, kj2qxw9k, x18rkx3s, afe4cwb0, srkhuy4g, mph51qok, bh2z0jkt, b0lwy3rk, qehytran, eionpvqj, oo4hq36z, a1x2cdf0, dn13x6ql, c7c480k2, uot4snvp, p3kvrg9j, mcugwbsp, qiz2bfkv, solj81d4, ku4r3omn, kro1j69u, gfm9e1z6, njhycz89 +# for run_id in ww0r248v ; do +# echo "results/$run_id/validation_chkpt00000_rank0000.zarr" +# rm -r "results/$run_id/validation_chkpt00000_rank0000.zarr" +# rm -r "results/$run_id/plots" +# done + + +# GRID SEARCH 2 +# beta1 0.6 0.7 0.8 0.9 0.95 +# beta2 0.8 0.9 0.95, 0.99 +# streams_directory="./config/streams/era5_1deg, "./config/streams/era5_1deg_w-aifs"" +# for beta1 in 0.6 0.7 0.8 0.9 0.95 ; do +# for beta2 in 0.8 0.9 0.95 0.99 ; do +# for sd in "./config/streams/era5_1deg" "./config/streams/era5_1deg_w-aifs" ; do +# echo "$beta1 $beta2 $sd" +# # ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --chain-jobs 3 --link-venv --options lr_max=0.0001 weight_decay=0.1 adam_beta1=$beta1 adam_beta2=$beta2 streams_directory=$sd +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --chain-jobs 3 --link-venv --options lr_max=0.00005 weight_decay=0.05 adam_beta1=$beta1 adam_beta2=$beta2 streams_directory=$sd +# done +# done +# done + +# # Train continue grid search 2 +# # for run_id in s3cr2ef4 mrcw4e8h djuk9pzn aepv7wu1 v3h9k2bj ynlvs89z tz49cgl2 l98sz3jw m1k9eplh tgmftqi4 pu6e03ng lnph48zk t69trx2e j9eiz4dq z1sd0pnr tide3hv6 m45w7y1t np85vcnq ldv8n37w cfblam6s fb4ms5ec w584rjeg ctg1usw6 jz8tjv6m w4xqvt69 qrjo9xs0 u34liwsg exvmlypc uubfj3gz h392evts tp8b24cj oxfpqjw1 xpv8qnyd xpw5o9fx ti4vmpsg qkq9yjl7 e1enpxz3 nwj4z09v rhk794ou h1pwbi4a bcmvof51 b2weib8p hzd5uet7 k5pilh01 v4ah1kzx ppxufsjg pcl0snok wurb6xtk npmszy6v y9kfdpom x9ay3kmw irn3mjyc z4awp1g5 gmlzaqhj r0spjzik knouy46d q3tjv9di dwp6e7d9 n1qikpey j2dg98i0 zed62zhu a4vqxo9u suz4ilra m3vscjad mclsg8op rgtp5jy9 jnwrgzpf w9rkhg1v ob6r9moj t6ibtc1j mgyzswul z271maoe q57y4ve3 ohce4138 ao13xq8w t6bjgt9a x762dlfb xdo9plre hlrt72oe hbnm4u1x ; do +# # for run_id in z08sckyz k52e7hfo p09l6hpz gl3ev20k iaw2gbrn m6jniu34 obwocxim ep5q8gzn tkywi1hd lu3lj0z4 y75fha42 kmdx1g3n kmztan34 pg9k1fli qgv2layj heomgws2 yo8jyrbv jfd9lr0u x3rqmjes gk1ny5oh c0yf9pdt n2vm0sxq fvq6tyeo ql6spk9b p5ibvwda cp0k1an8 c01zraug hv97duk1 i2lrahz9 ge6v74ir e6ijt39q sx6m2ejs va9d6yv3 nw43c2n0 z1yi6s2c jsfhgdq1 d3o1a9lc u0gf6k3l gf3r9noc nb6342qh zyxq24l1 t5uh3gv9 hng8tw29 xd60xpu5 y90egrso jnsxo6mt lfq8om96 hhqnr8l9 sqne63ot unq1uwez cm3tolj1 k8irbcyk fqxcans2 ujqxld2f ksdn58ca wirp7n6q ttlvz8f2 xx68golp sm9i8lgx qeopqgju d2mj6l1d u1h675cf dge0wox1 gz6ad0sc g0hb53if n6nymxjc l1np8abi pybuh3o4 gn8jcdug vj6gn49y lnwequ80 t486blik il6f0p87 yixm5qls uafuiesw fh4vuolp ey17ezol im8g2pzi ff1ldwox wb47lhdj ; do +# # for run_id in eosc19jg kl0rcua6 ukwujlf4 kfg9vle7 gp5cnl91 d5fwxbil lzg3iprw qxdtz5m1 p9d1ya0q x03dmoey f1pezdro i0jh85s1 gw6leyu4 ev2zrdsu b9kzho1t smgpkdbj tmk1adys wt89w7kp kuzwmra5 ezimdn39 npeywgca qx9bgzqw z8qvl037 jv6am9p8 lokz5def fvef6dpn gw3fz7rc f7jn9ep8 nblsm3uo ze1mh2cb t5fsvqxh vwz3ue4g ez2iolsm t946rtp8 bh65jf8k zb7s32vu momerjg2 o45qt3oc z8vwrchb lfc1rl63 gxibzqeo jx4pw8jb icnlv2kw lx5c0qh4 e2ndef59 aeqf5ozl e0riq4o6 psq79jhc cgncym6p idpvqbaw qn52qavo hygrvmh4 n6azvsmo difz8odj s5o06pyi v450oj3h g4n6iu78 yqtgkaz3 ezl2c8od vcophb7i kl7gop09 uhj4npb6 yjql2371 e1ungd2z awn0a856 hxtscki8 rln8xs45 i6jkuig9 s8qetp41 vn903brm sjxehz2y ch2cnwf1 qo7aihzs v2djkyu0 irmnhejc jct7zx6j y3ezyh1x prx0wqsk iecmarv3 kr4gtp38 ; do +# for run_id in eosc19jg kl0rcua6 ukwujlf4 kfg9vle7 gp5cnl91 d5fwxbil lzg3iprw qxdtz5m1 p9d1ya0q x03dmoey f1pezdro i0jh85s1 gw6leyu4 ev2zrdsu b9kzho1t smgpkdbj tmk1adys wt89w7kp kuzwmra5 ezimdn39 npeywgca qx9bgzqw z8qvl037 jv6am9p8 lokz5def fvef6dpn gw3fz7rc f7jn9ep8 nblsm3uo ze1mh2cb t5fsvqxh vwz3ue4g ez2iolsm t946rtp8 bh65jf8k zb7s32vu momerjg2 o45qt3oc z8vwrchb lfc1rl63 gxibzqeo jx4pw8jb icnlv2kw lx5c0qh4 e2ndef59 aeqf5ozl e0riq4o6 psq79jhc cgncym6p idpvqbaw qn52qavo hygrvmh4 n6azvsmo difz8odj s5o06pyi v450oj3h g4n6iu78 yqtgkaz3 ezl2c8od vcophb7i kl7gop09 uhj4npb6 yjql2371 e1ungd2z awn0a856 hxtscki8 rln8xs45 i6jkuig9 s8qetp41 vn903brm sjxehz2y ch2cnwf1 qo7aihzs v2djkyu0 irmnhejc jct7zx6j y3ezyh1x prx0wqsk iecmarv3 kr4gtp38 ; do +# echo "$run_id" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv +# done + + +# # DROPOUT [0.05, 0.1, 0.15, 0.2, 0.3, 0.4] +# for dropout in 0.05 0.1 0.15 0.2 0.3 0.4 ; do +# echo "$dropout" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --link-venv --options lr_max=0.00005 weight_decay=0.05 adam_beta1=0.85 adam_beta2=0.9 streams_directory="./config/streams/era5_1deg_w-aifs" embed_dropout_rate=$dropout ae_local_dropout_rate=$dropout ae_adapter_dropout_rate=$dropout ae_global_dropout_rate=$dropout fe_dropout_rate=$dropout +# done + +# # Train continue dropout +# for run_id in d2f1p4vh m8e7psdl n2nmxc7b flaucoz5 dnl5r61x aojt3c1z p1phw3g9 kacy7jbz uk0uvcfn d1fhev63 pxg7jnzt z40dbxjy fpymqrv3 pe93az4w saxqsfzb yjlzi5g7 zewh2o5n dy36qb7e ; do +# echo "$run_id" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv +# done + + +# # NOISE LEVEL [1e-3, 5e-4, (1e-4), 5e-5, 1e-5] +# for nl in "1e-3" "5e-4" "1e-4" "5e-5" "1e-5" ; do +# echo "$nl" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --link-venv --options lr_max=0.00005 weight_decay=0.05 adam_beta1=0.85 adam_beta2=0.9 streams_directory="./config/streams/era5_1deg_w-aifs" impute_latent_noise_std=$nl +# done + +# # Train continue noise level +# for run_id in fmpesclt h6lu3sh8 tgmwaifc vavdy4zf qf24wjsq nbc3il5x vvwizau9 wyhcr51m n207tod4 cey23p7w rh49o7yj qt72d4iy vl5n39cj ocyx09uw fydmc3vg ; do +# echo "$run_id" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv +# done + + + #for run_id in fmpesclt vvwizau9 vl5n39cj ; do + for run_id in r2z01faj xxjfcwq1 ; do + echo "$run_id" + ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --time 24:00:00 --from-run-id $run_id --run-id $run_id --link-venv + done + + +#for run_id in bwtkj0he pji5hbze pu5ct7ox; do +#for run_id in lvy8406i ipn3jryk it0uzsl3 qm45twzj d3gc8fdn ; do +#for run_id in it0uzsl3 qm45twzj d3gc8fdn ; do +#for run_id in ipn3jryk d3gc8fdn ; do +# echo "$run_id" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --time 24:00:00 --link-venv --options num_epochs=36 lr_steps_cooldown=2048 +#done + + +# Cosine learning_rate test +#for lr_max in "5e-4" "1e-4" "5e-5" "1e-5" "5e-6" ; do +#for lr_max in "1e-4" ; do +# echo "$lr_max" +# for from_run_id in dnl5r61x vvwizau9 wyhcr51m; do +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --time 24:00:00 --from-run-id $from_run_id --link-venv --options istep=0 num_epochs=48 lr_max=$lr_max lr_policy_decay="cosine" forecast_steps=8 freeze_modules=".*global.*|.*local.*|.*adapter.*|.*ERA5.*" +#done +#done diff --git a/launch_multi_infer.sh b/launch_multi_infer.sh new file mode 100644 index 000000000..f3ff9bc85 --- /dev/null +++ b/launch_multi_infer.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# (from_run_id, to_run_id) tuples from mapping +run_pairs=( + #"czfrhdae u2qk39pi" + #"zha9i6x3 zswipf53" + #"ypr1b3a4 dsdvzg59" + #"vctfgruv r812ji96" + #"htdwqjpx gj6eq2dx" + #"gakr74pw ff80snum" + #"qpogewjf v0yha29i" + #"y3trwpx7 cbmk73y0" + #"asnz2gyl ngdrjcbt" + #"dd1cq6nv voulcvsi" + #"dn15vfks urlp39xq" + #"fl9xrpao ch1n05gd" +) + +#for tuple in "${run_pairs[@]}"; do +# read from_run_id run_id <<< "$tuple" +# echo "From: $from_run_id → Run_id: $run_id" +# sbatch weather_slurm_inferece.sh "$from_run_id" "$run_id" +#done + + + + + +#for run_id in unov2gdz pv5hu3mc exsm2wty czfrhdae zha9i6x3 ypr1b3a4 vctfgruv htdwjqpx gakr74pw qpogewjf y3trwpx7 asnz2gyl dd1cq6nv dn15vfks fl9xrpao ; do +#for run_id in xqbky3ht whsolnr7 e0yzx968 ; do +for run_id in otn1u3oe r2z01faj xxjfcwq1 ; do + echo $ "$run_id" + sbatch weather_slurm_inferece.sh "$run_id" +done diff --git a/launch_multi_lr.sh b/launch_multi_lr.sh new file mode 100644 index 000000000..85e31cdcc --- /dev/null +++ b/launch_multi_lr.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Extract all unique (lr, from_run_id) tuples from mapping +tuples=( + "5e-4 dnl5r61x" # zd4t0zmp + "5e-4 vvwizau9" # fzmgdsev + "5e-4 wyhcr51m" # rqhn8y14 + "5e-5 dnl5r61x" # nibpxofg + "5e-5 wyhcr51m" # zylxr8pm + "1e-5 dnl5r61x" # s1urb38z + "1e-5 vvwizau9" # lmix3abo + "1e-5 wyhcr51m" # gu20n5l8 + "5e-6 dnl5r61x" # sxlqdhue + "5e-6 vvwizau9" # s8pfvmle +) + +echo "Launching ${#tuples[@]} experiments..." + +for tuple in "${tuples[@]}"; do + read lr_max from_run_id <<< "$tuple" + echo "=== $from_run_id @ $lr_max ===" + + ../WeatherGenerator-private/hpc/launch-slurm.py \ + --nodes 2 \ + --time 24:00:00 \ + --from-run-id "$from_run_id" \ + --link-venv \ + --options istep=0 num_epochs=32 lr_max=$lr_max lr_policy_decay="cosine" forecast_steps=8 freeze_modules=".*global.*|.*local.*|.*adapter.*|.*ERA5.*" + + echo "----------------------------------------" +done + +echo "All $((${#tuples[@]})) jobs submitted!" + diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 66fb2602d..fe12c58e5 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -469,7 +469,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non super().__init__(eval_cfg, run_id, private_paths) - self.mini_epoch = eval_cfg.mini_epoch + # TODO: remove backwards compatibility to "epoch" in Feb. 2026 + self.mini_epoch = getattr(eval_cfg, "mini_epoch", eval_cfg["epoch"]) self.rank = eval_cfg.rank # Load model configuration and set (run-id specific) directories @@ -889,7 +890,7 @@ def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | N """ score_path = ( Path(self.metrics_dir) - / f"{self.run_id}_{stream}_{region}_{metric}_epoch{self.epoch:05d}.json" + / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" ) _logger.debug(f"Looking for: {score_path}") diff --git a/packages/evaluate/src/weathergen/evaluate/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotter.py index cb15e6f24..007b957e7 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotter.py @@ -14,8 +14,8 @@ from matplotlib.lines import Line2D from PIL import Image from scipy.stats import wilcoxon - from weathergen.common.config import _load_private_conf + from weathergen.evaluate.plot_utils import ( DefaultMarkerSize, ) @@ -482,7 +482,7 @@ def scatter_plot( # TODO: make this nicer parts = ["map", self.run_id, tag] - if self.sample: + if self.sample is not None: parts.append(str(self.sample)) if "valid_time" in data.coords: diff --git a/pyproject.toml b/pyproject.toml index 0f0f7a296..7051bfa36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "weathergen-common", "weathergen-evaluate", "weathergen-readers-extra", + "pyyaml>=6.0.2", ] diff --git a/scripts/model_weight_progression.py b/scripts/model_weight_progression.py new file mode 100644 index 000000000..1af1aaaad --- /dev/null +++ b/scripts/model_weight_progression.py @@ -0,0 +1,82 @@ +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import torch +import tqdm + + +def load_checkpoint(run_id: str, epoch: int) -> dict[str, torch.Tensor]: + chkpt = torch.load(f"models/{run_id}/{run_id}_epoch{str(epoch).zfill(5)}.chkpt") + fe_keys = [key for key in list(chkpt.keys()) if "fe" in key] + fe_chkpt = {fe_key: chkpt[fe_key] for fe_key in fe_keys} + return fe_keys, fe_chkpt + + +def get_layer_and_name(key: str) -> [int, str]: + key_split = key.split(".") + layer = int(key_split[1]) + name = ".".join(key_split[2:]) + return layer, name + + +def prepare_weights_and_eigenvalues( + w_dict: dict[str, torch.Tensor], +) -> [dict[str, list], dict[str, list]]: + # Compute eigenvectors of each layer. Set to [0, 0] if no matrix. + e_dict = { + key: (w_dict[key].svd().S.cpu().numpy()) if len(w_dict[key].shape) > 1 else [0, 0] + for key in w_dict + } + # Flatten all weights + w_dict = {key: w_dict[key].flatten().cpu().numpy() for key in w_dict} + return w_dict, e_dict + + +def plot_results( + w_dict: dict[str, torch.Tensor], epoch: int, layers: int, run_id: str, plot_dir: str +): + w_dict, e_dict = prepare_weights_and_eigenvalues(w_dict=w_dict) + fig, axs = plt.subplots(2, 1, figsize=(len(w_dict.keys()), 5), sharex=True) + axs[0].boxplot(w_dict.values(), tick_labels=w_dict.keys()) + axs[1].violinplot(e_dict.values()) + axs[0].grid() + axs[1].grid() + axs[0].set_title("Weight distribution") + axs[1].set_title("Singular value distribution") + plt.xticks(rotation=45, ha="right") + os.makedirs(plot_dir, exist_ok=True) + plot_path = plot_dir / Path(f"w-dist_{run_id}_epoch{str(epoch).zfill(3)}.png") + fig.savefig(plot_path, bbox_inches="tight", pad_inches=0) + + +if __name__ == "__main__": + run_id = "vso7p6dt" + epochs = [2, 4, 8, 16, 32, 63] + plot_dir = Path("plots", "w_dist", run_id) + + for epoch in tqdm.tqdm(epochs, desc="Processing epoch"): + fe_keys, fe_w_dict = load_checkpoint(run_id=run_id, epoch=epoch) + + # + # Option 1: All layers in one plot + plot_results(w_dict=fe_w_dict, epoch=epoch, layers=15, run_id=run_id, plot_dir=plot_dir) + + # + # Option 2: One plot per layer + # layer = -1 # init + # for fe_key in fe_keys: + # l, name = get_layer_and_name(key=fe_key) + + # if layer != l: + # if layer != -1: + # plot_results(w_dict=w_per_layer, epoch=epoch, layers=15, run_id=run_id, plot_dir=plot_dir) + # # Reset layer dict for new layer + # layer = l + # w_per_layer = dict() + + # w_per_layer[name] = fe_w_dict[fe_key] + + # print(fe_key) + + # plot_results(w_dict=w_per_layer, epoch=epoch, layer=l) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 7359d1403..3f8ecc3e8 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -298,6 +298,10 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_eps=self.cf.mlp_norm_eps, ) ) + if self.cf.get("ae_global_trailing_layer_norm", False): + self.ae_global_blocks.append( + torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) + ) def forward(self, tokens, use_reentrant): for block in self.ae_global_blocks: @@ -308,7 +312,7 @@ def forward(self, tokens, use_reentrant): class ForecastingEngine(torch.nn.Module): name: "ForecastingEngine" - def __init__(self, cf: Config, num_healpix_cells: int) -> None: + def __init__(self, cf: Config, num_healpix_cells: int, dim_aux: int = None) -> None: """ Initialize the ForecastingEngine with the configuration. @@ -333,7 +337,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, norm_type=self.cf.norm_type, - dim_aux=1, + dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), ) @@ -349,7 +353,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, norm_type=self.cf.norm_type, - dim_aux=1, + dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), ) @@ -362,10 +366,15 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_residual=True, dropout_rate=self.cf.fe_dropout_rate, norm_type=self.cf.norm_type, - dim_aux=1, + dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, ) ) + # Optionally, add LayerNorm after i-th layer + if i in self.cf.get("fe_layer_norm_after_blocks", []): + self.fe_blocks.append( + torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) + ) def init_weights_final(m): if isinstance(m, torch.nn.Linear): @@ -377,10 +386,12 @@ def init_weights_final(m): block.apply(init_weights_final) def forward(self, tokens, fstep): - aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") - for block in self.fe_blocks: - tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) - + aux_info = None + for b_idx, block in enumerate(self.fe_blocks): + if isinstance(block, torch.nn.modules.normalization.LayerNorm): + tokens = block(tokens) + else: + tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) return tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 8a4524c1a..ba021860b 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -597,6 +597,28 @@ def forward_jac(self, *args): return tuple(preds_all[0]) + ######################################### + def plot_token_distribution(self, tokens, fstep): + # When validating (distributed setup), don't plot the token distribution + if tokens.dtype == torch.bfloat16: + return + + plot_path = Path(self.cf.run_path, self.cf.run_id, "plots", "ERA5", "latent_hists") + import os + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.hist(tokens.flatten().to("cpu").numpy(), bins=30) + if not hasattr(self, "xlim"): + self.xlim = np.array(ax.get_xlim()) + self.ylim = np.array(ax.get_ylim()) + ax.set_xlim(0.5 * self.xlim) + ax.set_ylim(self.ylim) + ax.set_title(f"Forecast step {fstep}") + os.makedirs(plot_path, exist_ok=True) + fig.savefig(plot_path / f"fstep_{str(fstep).zfill(3)}.png") + plt.close() + ######################################### def forward(self, model_params: ModelParams, batch, forecast_offset: int, forecast_steps: int): """Performs the forward pass of the model to generate forecasts @@ -626,6 +648,9 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens = self.assimilate_global(model_params, tokens) + if not self.training: + self.plot_token_distribution(tokens=tokens, fstep=0) + # roll-out in latent space preds_all = [] for fstep in range(forecast_offset, forecast_offset + forecast_steps): @@ -648,6 +673,9 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens = self.forecast(model_params, tokens, fstep) + if not self.training: + self.plot_token_distribution(tokens=tokens, fstep=fstep) + # prediction for final step preds_all += [ self.predict( diff --git a/src/weathergen/train/loss.py b/src/weathergen/train/loss.py index 406cd051c..5f620a810 100644 --- a/src/weathergen/train/loss.py +++ b/src/weathergen/train/loss.py @@ -195,3 +195,24 @@ def gamma_decay(forecast_steps, gamma): fsteps = np.arange(forecast_steps) weights = gamma**fsteps return weights * (len(fsteps) / np.sum(weights)) + + +def spike_function(forecast_steps, spike_type): + fstep = np.arange(forecast_steps) + weights = np.zeros_like(fstep, dtype=float) + if spike_type["type"] == "last": + weights[-1] = 1.0 + elif spike_type["type"] == "probability": + steps_probs = spike_type["values"] + fs_steps = list(steps_probs.keys()) + fs_steps = [int(x) for x in fs_steps] + assert max(fs_steps) <= forecast_steps, ( + f"Max step {max(fs_steps)} > forecast_steps {forecast_steps}" + ) + fs_probs = list(steps_probs.values()) + assert np.isclose(np.array(fs_probs).sum(), 1.0) + fs_selected = np.random.choice(fs_steps, p=fs_probs) + weights[fs_selected] = 1.0 + else: + raise ValueError(f"Spike type {spike_type} is not defined") + return weights diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 796bd2b3f..6473d156f 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -339,15 +339,20 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): if is_root() and not cf.with_fsdp and not cf.with_ddp: self.model.print_num_parameters() + # Retrieve Adam betas from config or compute them dynamically if not specified + beta1, beta2 = cf.get("adam_beta1", None), cf.get("adam_beta2", None) + # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ # aiming for beta1=0.9 and beta2=0.95 following the MAE paper https://arxiv.org/pdf/2111.06377 kappa = ( cf.batch_size_per_gpu * cf.world_size ) # I doubt this holds for us from some anecdotal runs + # aiming for beta1 = 0.9 at one node, ie kappa=B=4 beta1 = max( 0.5, 1.0 - kappa * (1.0 - 0.975) - ) # aiming for beta1 = 0.9 at one node, ie kappa=B=4 - beta2 = 1.0 - kappa * (1.0 - 0.9875) # aiming for beta2 = 0.95 at one node, ie B=4 + ) if beta1 is None else beta1 + # aiming for beta2 = 0.95 at one node, ie B=4 + beta2 = 1.0 - kappa * (1.0 - 0.9875) if beta2 is None else beta2 eps = 2e-08 / np.sqrt(kappa) self.optimizer = torch.optim.AdamW( @@ -588,6 +593,7 @@ def train(self, mini_epoch): # Unweighted loss, real weighted loss, std for losses that need it self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], [] + self.last_grad_norm = 0.0 # training loop self.t_start = time.time() @@ -795,6 +801,7 @@ def load_model(self, run_id: str, mini_epoch=-1): is_model_sharded = self.cf.with_ddp and self.cf.with_fsdp if is_model_sharded: + params = self.model.rename_old_state_dict(params=params) # For backward compatibility meta_sharded_sd = self.model.state_dict() maybe_sharded_sd = {} for param_name, full_tensor in params.items(): diff --git a/src/weathergen/utils/plot_grad_norms.py b/src/weathergen/utils/plot_grad_norms.py new file mode 100644 index 000000000..ec310c0fc --- /dev/null +++ b/src/weathergen/utils/plot_grad_norms.py @@ -0,0 +1,525 @@ +import json +import re +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +# ruff: noqa: T201 + + +class GradientNormsAnalyzer: + def __init__(self, json_file_path): + """ + Initialize the analyzer with path to JSON file containing gradient norms. + Expected format: one JSON object per line with step info and gradient norms. + """ + self.json_file_path = Path(json_file_path) + self.data = [] + self.df = None + self.load_data() + + def load_data(self): + """Load and parse the JSON data from file.""" + print(f"Loading data from {self.json_file_path}...") + + with open(self.json_file_path) as f: + for line_num, line in enumerate(f, 1): + try: + data_point = json.loads(line.strip()) + self.data.append(data_point) + except json.JSONDecodeError as e: + print(f"Warning: Could not parse line {line_num}: {e}") + + print(f"Loaded {len(self.data)} data points") + self.create_dataframe() + + def create_dataframe(self): + """Convert loaded data into a pandas DataFrame for easier analysis.""" + rows = [] + + for ith, entry in enumerate(self.data): + # step = entry.get('num_samples', entry.get('epoch', 0)) + step = ith * 5 + + # Handle different possible data structures + if "gradients" in entry: + grad_data = entry["gradients"] + elif "grad_norms" in entry: + grad_data = entry["grad_norms"] + else: + # Assume all keys except step/epoch are gradient data + grad_data = { + k: v for k, v in entry.items() if "stream" not in k and ("grad_norm" in k) + } + + for param_name, norm_value in grad_data.items(): + rows.append( + { + "num_samples": step, + "parameter": param_name, + "grad_norm": float(norm_value), + "layer_type": self.extract_layer_type(param_name), + "layer_depth": self.extract_layer_depth(param_name), + } + ) + + self.df = pd.DataFrame(rows) + print(f"Created DataFrame with {len(self.df)} gradient norm records") + + def extract_layer_type(self, param_name): + """Extract layer type from parameter name.""" + param_name_lower = param_name.lower()[10:] + + # Handle your specific naming patterns + if param_name_lower.startswith("embeds."): + if ".embed." in param_name_lower: + return "embedding" + elif ".unembed." in param_name_lower: + return "unembedding" + elif ".ln_final." in param_name_lower: + return "layer_norm_final" + elif "proj_heads_q" in param_name_lower: + return "attention_q" + elif "proj_heads_k" in param_name_lower: + return "attention_k" + elif "proj_heads_v" in param_name_lower: + return "attention_v" + elif "proj_out" in param_name_lower: + return "attention_out" + elif ".layers." in param_name_lower and ( + "weight" in param_name_lower or "bias" in param_name_lower + ): + return "ffn" + else: + return "embeds_other" + + elif param_name_lower.startswith("ae_local_blocks."): + if "proj_heads_q" in param_name_lower: + return "ae_local_attention_q" + elif "proj_heads_k" in param_name_lower: + return "ae_local_attention_k" + elif "proj_heads_v" in param_name_lower: + return "ae_local_attention_v" + elif "proj_out" in param_name_lower: + return "ae_local_attention_out" + elif ".layers." in param_name_lower: + return "ae_local_ffn" + else: + return "ae_local_other" + + elif param_name_lower.startswith("ae_global_blocks."): + if "proj_heads_q" in param_name_lower: + return "ae_global_attention_q" + elif "proj_heads_k" in param_name_lower: + return "ae_global_attention_k" + elif "proj_heads_v" in param_name_lower: + return "ae_global_attention_v" + elif "proj_out" in param_name_lower: + return "ae_global_attention_out" + elif ".layers." in param_name_lower: + return "ae_global_ffn" + else: + return "ae_global_other" + + elif param_name_lower.startswith("ae_adapter."): + if "proj_heads_q" in param_name_lower: + return "ae_adapter_attention_q" + elif "proj_heads_k" in param_name_lower: + return "ae_adapter_attention_k" + elif "proj_heads_v" in param_name_lower: + return "ae_adapter_attention_v" + elif "proj_out" in param_name_lower: + return "ae_adapter_attention_out" + elif ".layers." in param_name_lower: + return "ae_adapter_ffn" + else: + return "ae_adapter_other" + + elif param_name_lower.startswith("target_token_engines."): + if "proj_heads_q" in param_name_lower: + return "tte_attention_q" + elif "proj_heads_k" in param_name_lower: + return "tte_attention_k" + elif "proj_heads_v" in param_name_lower: + return "tte_attention_v" + elif "proj_out" in param_name_lower: + return "tte_attention_out" + elif "embed_aux" in param_name_lower: + return "tte_embed_aux" + elif "lnorm" in param_name_lower: + return "tte_layer_norm" + elif ".layers." in param_name_lower: + return "tte_ffn" + else: + return "tte_other" + + elif param_name_lower.startswith("embed_target_coords."): + return "target_coords_embedding" + + elif param_name_lower.startswith("pred_heads."): + return "prediction_head" + + # Fallback for standard patterns (if any) + elif "embed" in param_name_lower: + return "embedding" + elif "attention" in param_name_lower or "attn" in param_name_lower: + if "q_proj" in param_name_lower or "query" in param_name_lower: + return "attention_q" + elif "k_proj" in param_name_lower or "key" in param_name_lower: + return "attention_k" + elif "v_proj" in param_name_lower or "value" in param_name_lower: + return "attention_v" + elif "o_proj" in param_name_lower or "out" in param_name_lower: + return "attention_out" + else: + return "attention" + elif ( + "layernorm" in param_name_lower + or "layer_norm" in param_name_lower + or "ln" in param_name_lower + ): + return "layernorm" + else: + return "other" + + def extract_layer_depth(self, param_name): + """Extract layer depth/index from parameter name.""" + param_name_lower = param_name.lower() + + # Look for patterns specific to your architecture + patterns = [ + # embeds.0.layers.N.* (transformer layers within embeds) + r"grad_norm_embeds\.\d+\.layers\.(\d+)\.", + # embeds.0.unembed.N.* (unembedding layers) + r"grad_norm_embeds\.\d+\.unembed\.(\d+)\.", + # embeds.0.ln_final.N.* (final layer norms) + r"grad_norm_embeds\.\d+\.ln_final\.(\d+)\.", + # ae_local_blocks.N.* (autoencoder local blocks) + r"grad_norm_ae_local_blocks\.(\d+)\.", + # ae_global_blocks.N.* (autoencoder global blocks) + r"ae_global_blocks\.(\d+)\.", + # ae_adapter.N.* (autoencoder adapter blocks) + r"ae_adapter\.(\d+)\.", + # target_token_engines.0.tte.N.* (target token engine blocks) + r"target_token_engines\.\d+\.tte\.(\d+)\.", + # target_token_engines.0.tte.N.block.M.* (nested blocks) + r"target_token_engines\.\d+\.tte\.(\d+)\.block\.(\d+)\.", + # pred_heads.0.pred_heads.0.N.* (prediction head layers) + r"pred_heads\.\d+\.pred_heads\.\d+\.(\d+)\.", + # Generic patterns for any numbered layers + r"layer[s]?\.(\d+)", + r"h\.(\d+)", + r"transformer\.(\d+)", + r"blocks\.(\d+)", + ] + + for pattern in patterns: + match = re.search(pattern, param_name_lower) + if match: + # For nested patterns (like tte blocks), combine indices + if len(match.groups()) > 1: + # Combine indices: e.g., tte.1.block.2 -> 12 (or 1*10+2) + return int(match.group(1)) * 10 + int(match.group(2)) + else: + return int(match.group(1)) + + # Special handling for components without clear depth + if param_name_lower.startswith("embed_target_coords."): + return 0 # Coordinate embeddings at the start + elif "total_grad_norm" in param_name_lower: + return -2 # Special marker for total norm + elif any(x in param_name_lower for x in ["weathergen", "stage", "q_cells"]): + return -3 # Special marker for metadata + + return -1 # Unknown depth + + def plot_total_gradient_norms(self, figsize=(12, 6)): + """Plot total gradient norm over training steps.""" + # Calculate total norm per step + total_norms = [] + steps = [] + + for ith, entry in enumerate(self.data): + # step = entry.get('num_samples', entry.get('epoch', 0)) + step = ith * 5 + + if "gradients" in entry: + grad_data = entry["gradients"] + elif "grad_norms" in entry: + grad_data = entry["grad_norms"] + else: + grad_data = {k: v for k, v in entry.items() if "grad_norm" in k} + + if len(grad_data) == 0: + continue + + # Calculate total norm (L2 norm of all gradients) + total_norm = np.sqrt(sum(float(v) ** 2 for v in grad_data.values())) + total_norms.append(total_norm) + steps.append(step) + + plt.figure(figsize=figsize) + plt.plot(steps, total_norms, linewidth=1.5, alpha=0.8) + plt.xlabel("Training Step") + plt.ylabel("Total Gradient Norm") + plt.title("Total Gradient Norm vs Training Steps") + plt.yscale("log") + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig("plots/total_grad_norm.png") + + return steps, total_norms + + def plot_layer_type_norms(self, figsize=(14, 8)): + """Plot gradient norms grouped by layer type.""" + if self.df is None: + print("No DataFrame available. Load data first.") + return + + plt.figure(figsize=figsize) + + # Get unique layer types + layer_types = self.df["layer_type"].unique() + print(layer_types) + colors = plt.cm.tab10(np.linspace(0, 1, len(layer_types))) + + for i, layer_type in enumerate(layer_types): + layer_data = self.df[self.df["layer_type"] == layer_type] + + # Calculate mean gradient norm per step for this layer type + mean_norms = layer_data.groupby("num_samples")["grad_norm"].mean() + + plt.plot( + mean_norms.index, mean_norms.values, label=layer_type, color=colors[i], alpha=0.8 + ) + + plt.xlabel("Training Step") + plt.ylabel("Mean Gradient Norm") + plt.title("Gradient Norms by Layer Type") + plt.yscale("log") + plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig("plots/grad_norm_by_layer_type.png") + + def plot_layer_depth_analysis(self, figsize=(12, 8)): + """Plot gradient norms by layer depth.""" + if self.df is None: + print("No DataFrame available. Load data first.") + return + + # Filter out unknown depths + depth_data = self.df[self.df["layer_depth"] >= 0] + + if len(depth_data) == 0: + print("No layer depth information found in parameter names.") + return + + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize) + + # Plot 1: Mean gradient norm by depth over time + depths = sorted(depth_data["layer_depth"].unique()) + colors = plt.cm.viridis(np.linspace(0, 1, len(depths))) + + for i, depth in enumerate(depths): + layer_data = depth_data[depth_data["layer_depth"] == depth] + mean_norms = layer_data.groupby("num_samples")["grad_norm"].mean() + + ax1.plot( + mean_norms.index, + mean_norms.values, + label=f"Layer {depth}", + color=colors[i], + alpha=0.8, + ) + + ax1.set_xlabel("Training Step") + ax1.set_ylabel("Mean Gradient Norm") + ax1.set_title("Gradient Norms by Layer Depth") + ax1.set_yscale("log") + ax1.legend(bbox_to_anchor=(1.05, 1), loc="upper left") + ax1.grid(True, alpha=0.3) + + # Plot 2: Heatmap of gradient norms by depth and step + pivot_data = ( + depth_data.groupby(["num_samples", "layer_depth"])["grad_norm"].mean().unstack() + ) + + # Sample data if too many steps for readability + if len(pivot_data) > 100: + sample_idx = np.linspace(0, len(pivot_data) - 1, 100, dtype=int) + pivot_data = pivot_data.iloc[sample_idx] + + im = ax2.imshow( + pivot_data.T, + aspect="auto", + cmap="viridis", + extent=[ + pivot_data.index.min(), + pivot_data.index.max(), + pivot_data.columns.min(), + pivot_data.columns.max(), + ], + ) + ax2.set_xlabel("Training Step") + ax2.set_ylabel("Layer Depth") + ax2.set_title("Gradient Norm Heatmap (Layer Depth vs Step)") + + cbar = plt.colorbar(im, ax=ax2) + cbar.set_label("Gradient Norm") + + plt.tight_layout() + plt.savefig("plots/grad_norm_heatmap.png") + + def plot_gradient_distribution(self, figsize=(15, 10)): + """Plot distribution of gradient norms.""" + if self.df is None: + print("No DataFrame available. Load data first.") + return + + fig, axes = plt.subplots(2, 2, figsize=figsize) + + # Plot 1: Histogram of all gradient norms + axes[0, 0].hist(np.log10(self.df["grad_norm"].values), bins=50, alpha=0.7) + axes[0, 0].set_xlabel("Log10(Gradient Norm)") + axes[0, 0].set_ylabel("Frequency") + axes[0, 0].set_title("Distribution of Gradient Norms (Log Scale)") + axes[0, 0].grid(True, alpha=0.3) + + # Plot 2: Box plot by layer type + layer_types = self.df["layer_type"].unique()[:10] # Limit to 10 for readability + plot_data = [ + np.log10(self.df[self.df["layer_type"] == lt]["grad_norm"].values) for lt in layer_types + ] + + axes[0, 1].boxplot(plot_data, labels=layer_types) + axes[0, 1].set_xlabel("Layer Type") + axes[0, 1].set_ylabel("Log10(Gradient Norm)") + axes[0, 1].set_title("Gradient Norm Distribution by Layer Type") + axes[0, 1].tick_params(axis="x", rotation=45) + axes[0, 1].grid(True, alpha=0.3) + + # Plot 3: Gradient norms over time (sample of parameters) + sample_params = self.df["parameter"].unique()[:20] # Sample 20 parameters + for param in sample_params: + param_data = self.df[self.df["parameter"] == param] + axes[1, 0].plot( + param_data["num_samples"], param_data["grad_norm"], alpha=0.6, linewidth=0.8 + ) + + axes[1, 0].set_xlabel("Training Step") + axes[1, 0].set_ylabel("Gradient Norm") + axes[1, 0].set_title("Individual Parameter Gradient Norms (Sample)") + axes[1, 0].set_yscale("log") + axes[1, 0].grid(True, alpha=0.3) + + # Plot 4: Statistics over time + stats_by_step = self.df.groupby("num_samples")["grad_norm"].agg( + ["mean", "std", "min", "max"] + ) + + axes[1, 1].fill_between( + stats_by_step.index, + stats_by_step["mean"] - stats_by_step["std"], + stats_by_step["mean"] + stats_by_step["std"], + alpha=0.3, + label="±1 std", + ) + axes[1, 1].plot(stats_by_step.index, stats_by_step["mean"], label="Mean", linewidth=2) + axes[1, 1].plot( + stats_by_step.index, stats_by_step["max"], label="Max", linewidth=1, alpha=0.8 + ) + axes[1, 1].plot( + stats_by_step.index, stats_by_step["min"], label="Min", linewidth=1, alpha=0.8 + ) + + axes[1, 1].set_xlabel("Training Step") + axes[1, 1].set_ylabel("Gradient Norm") + axes[1, 1].set_title("Gradient Norm Statistics Over Time") + axes[1, 1].set_yscale("log") + axes[1, 1].legend() + axes[1, 1].grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig("plots/grad_norm_over_time.png") + + def generate_summary_report(self): + """Generate a summary report of gradient norm statistics.""" + if self.df is None: + print("No DataFrame available. Load data first.") + return + + print("=== GRADIENT NORMS ANALYSIS REPORT ===") + print(f"Total data points: {len(self.df)}") + print(f"Training steps: {self.df['num_samples'].nunique()}") + print(f"Unique parameters: {self.df['parameter'].nunique()}") + print() + + print("Overall Statistics:") + print(f"Mean gradient norm: {self.df['grad_norm'].mean():.6f}") + print(f"Median gradient norm: {self.df['grad_norm'].median():.6f}") + print(f"Min gradient norm: {self.df['grad_norm'].min():.6f}") + print(f"Max gradient norm: {self.df['grad_norm'].max():.6f}") + print() + + print("Statistics by Layer Type:") + layer_stats = self.df.groupby("layer_type")["grad_norm"].agg( + ["count", "mean", "std", "min", "max"] + ) + print(layer_stats) + print() + + # Check for potential issues + print("Potential Issues:") + very_small = (self.df["grad_norm"] < 1e-6).sum() + very_large = (self.df["grad_norm"] > 10.0).sum() + + if very_small > 0: + print(f"⚠️ {very_small} gradient norms < 1e-6 (possible vanishing gradients)") + if very_large > 0: + print(f"⚠️ {very_large} gradient norms > 10.0 (possible exploding gradients)") + + if very_small == 0 and very_large == 0: + print("✅ No obvious gradient issues detected") + + +# Usage example +def analyze_gradient_file(json_file_path): + """ + Main function to analyze gradient norms from a JSON file. + + Usage: + analyze_gradient_file('gradient_norms.jsonl') + """ + + analyzer = GradientNormsAnalyzer(json_file_path) + + # Generate summary report + analyzer.generate_summary_report() + + # Create all plots + print("\n=== GENERATING PLOTS ===") + + print("1. Total gradient norms over time...") + analyzer.plot_total_gradient_norms() + + print("2. Gradient norms by layer type...") + analyzer.plot_layer_type_norms() + + print("3. Layer depth analysis...") + analyzer.plot_layer_depth_analysis() + + print("4. Gradient distribution analysis...") + analyzer.plot_gradient_distribution() + + return analyzer + + +# Example usage: +# uv run python src/weathergen/utils/plot_grad_norms.py results/yvhxm2jc/yvhxm2jc_train_metrics.json +if __name__ == "__main__": + import sys + + analyzer = analyze_gradient_file(sys.argv[1]) diff --git a/uv.lock b/uv.lock index 4cdcbdcc5..d10a5e2dd 100644 --- a/uv.lock +++ b/uv.lock @@ -2728,6 +2728,7 @@ dependencies = [ { name = "polars", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "psutil", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "pynvml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pyyaml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "tqdm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-common", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-evaluate", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -2777,6 +2778,7 @@ requires-dist = [ { name = "polars", specifier = "~=1.25.2" }, { name = "psutil" }, { name = "pynvml" }, + { name = "pyyaml", specifier = ">=6.0.2" }, { name = "torch", marker = "platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-linux_aarch64.whl" }, { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl" }, { name = "torch", marker = "sys_platform == 'linux' and extra == 'cpu'", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "weathergen", extra = "cpu" } }, diff --git a/weather_slurm_inferece.sh b/weather_slurm_inferece.sh new file mode 100644 index 000000000..7818b2eb6 --- /dev/null +++ b/weather_slurm_inferece.sh @@ -0,0 +1,117 @@ +#!/bin/bash + +#SBATCH --job-name=train +#SBATCH --output=./logs/output_%j.txt +#SBATCH --error=./logs/error_%j.txt +#SBATCH --exclusive --mem=450G +#SBATCH --partition=normal +#SBATCH --gres=gpu:1 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --time=01:00:00 +#SBATCH -A ch17 +#SBATCH --output=logs/weathergen-%x.%j.out +#SBATCH --error=logs/weathergen-%x.%j.err + + +UENV_IMAGE="prgenv-gnu/25.6:v2" + +FROM_RUN_ID="$1" +#RUN_ID="$2" + +export FROM_RUN_ID +#export RUN_ID + +echo "Top-level from_run_id: $FROM_RUN_ID" +echo "Top-level run_id: $RUN_ID" + +echo "=== Checking for uenv image: $UENV_IMAGE ===" + +IMAGE_EXISTS=false + +if [ "$IMAGE_EXISTS" = false ]; then + if uenv image inspect "$UENV_IMAGE" &>/dev/null; then + IMAGE_EXISTS=true + fi +fi + +if [ "$IMAGE_EXISTS" = false ]; then + echo "========================================" + echo "ERROR: uenv image '$UENV_IMAGE' not found!" + echo "========================================" + echo "" + echo "The image needs to be pulled before use." + echo "" + echo "Steps to fix:" + echo "" + echo " 1. On the santis login node, run:" + echo " uenv image pull $UENV_IMAGE" + echo "" + echo " 2. Wait for download to complete (this may take a few minutes)" + echo "" + echo " 3. Verify the image is available:" + echo " uenv image ls" + echo "" + echo " 4. Re-submit your SLURM job" + echo "" + echo "========================================" + exit 1 +fi + +echo "✓ Image '$UENV_IMAGE' found" +echo "" + +FROM_RUN_ID="$1" +#RUN_ID="$2" + +uenv run "$UENV_IMAGE" --view=modules -- bash << 'EOF' + +module load aws-ofi-nccl/1.16.0 + +export NCCL_NET="AWS Libfabric" +export MPICH_GPU_SUPPORT_ENABLED=0 +export NCCL_NET_GDR_LEVEL=PHB +export NCCL_CROSS_NIC=1 +export NCCL_PROTO=^LL128 + +export FI_CXI_DEFAULT_CQ_SIZE=131072 +export FI_CXI_DEFAULT_TX_SIZE=16384 +export FI_CXI_DISABLE_HOST_REGISTER=1 +export FI_CXI_RX_MATCH_MODE=software +export FI_MR_CACHE_MONITOR=userfaultfd + +export MASTER_ADDR="$(scontrol show hostnames "$SLURM_NODELIST" | head -n 1)" +export MASTER_PORT=29514 + +# disable core dumps +ulimit -c 0 +ulimit -t unlimited + +export CC=/usr/bin/gcc +export NCCL_DEBUG=INFO + +echo "Starting job." +echo "Number of Nodes: $SLURM_JOB_NUM_NODES" +echo "Number of Tasks: $SLURM_NTASKS" +echo "from_run_id: $FROM_RUN_ID" +#echo "run_id: $RUN_ID" +echo "WEATHERGEN_HOME: $WEATHERGEN_HOME" +echo "WEATHERGEN_CONFIG_EXTRA: $WEATHERGEN_CONFIG_EXTRA" +echo "SLURM_JOB_ID: $SLURM_JOB_ID" +echo "SLURM_JOB_NAME: $SLURM_JOB_NAME" +echo "SLURM_SUBMIT_DIR: $SLURM_SUBMIT_DIR" +echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST" +date + + +#cd $WEATHERGEN_HOME +source .venv/bin/activate + +srun uv run --offline inference --from_run_id "$FROM_RUN_ID" --samples=16 --start_date=2023-10-01 --end_date=2023-12-01 --options forecast_steps=80 +#srun uv run inference --from_run_id "$FROM_RUN_ID" --run_id "$RUN_ID" --samples 16 --start_date=2023-10-01 --end_date=2023-12-01 --options forecast_steps=80 + +echo "Finished job." +sstat -j $SLURM_JOB_ID.batch --format=JobID,MaxVMSize +date +EOF +