29 changes: 29 additions & 0 deletions config/evaluate/eval_config_lst.yml
@@ -0,0 +1,29 @@
image_format : "png" # options: "png", "pdf", "svg", "eps", "jpg", ...
dpi_val : 300
summary_plots : true
print_summary: true

evaluation:
metrics : ["rmse", "mae"]
regions: ["madagaskar"]
summary_dir: "./plots/"
plot_score_maps: false # plot scores on 2D maps; slows down score computation
print_summary: false # print score values to the screen; can be verbose

run_ids :

ndl2qget : # Inference run id.
label: "One-shot LST prediction"
mini_epoch: 0
rank: 0
streams:
SEVIRI_LST:
channels: ["LST"] # alternatives, e.g. ["2t", "q_850"]
evaluation:
sample: "all"
forecast_step: "all"
plotting:
sample: [0, 1]
forecast_step: [1, 2, 3, 4, 5, 6]
plot_maps: true
plot_histograms: true
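For orientation, here is a minimal sketch of how a config in this shape can be loaded and traversed (assuming PyYAML; the traversal loop is illustrative and not part of this PR):

```python
import yaml

# Load the evaluation config added in this PR.
with open("config/evaluate/eval_config_lst.yml") as f:
    cfg = yaml.safe_load(f)

# Walk every run id and stream to see what will be evaluated and plotted.
for run_id, run_cfg in cfg["run_ids"].items():
    print(run_id, "-", run_cfg["label"])
    for stream_name, stream_cfg in run_cfg["streams"].items():
        print("  stream:", stream_name, "channels:", stream_cfg["channels"])
        print("  plot samples:", stream_cfg["plotting"]["sample"],
              "forecast steps:", stream_cfg["plotting"]["forecast_step"])
```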
194 changes: 194 additions & 0 deletions config/lst_config.yml
@@ -0,0 +1,194 @@
streams_directory: "./config/streams/seviri_lst/"

embed_orientation: "channels"
embed_local_coords: True
embed_centroids_local_coords: False
embed_size_centroids: 0
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

target_cell_local_prediction: True

ae_local_dim_embed: 1024
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 8
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
# TODO: switching to < 1 triggers triton-related issues.
# See https://github.com/ecmwf/WeatherGenerator/issues/1050
ae_global_att_dense_rate: 1.0
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2

ae_aggregation_num_blocks: 2
ae_aggregation_num_heads: 32
ae_aggregation_dropout_rate: 0.1
ae_aggregation_with_qk_lnorm: True
ae_aggregation_att_dense_rate: 1.0
ae_aggregation_block_factor: 64
ae_aggregation_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True

# number of steps offset applied to the first target window; if set to zero and forecast_steps=0,
# then one is training an auto-encoder
forecast_offset : 0
forecast_delta_hrs: 0
forecast_steps: 0
forecast_policy: null
forecast_att_dense_rate: 1.0
fe_num_blocks: 0
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
impute_latent_noise_std: 0.0 # 1e-4

healpix_level: 5

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mixed_precision_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

batch_size_per_gpu: 1
batch_size_validation_per_gpu: 1

# a regex that needs to fully match the name of the modules you want to freeze,
# e.g. ".*ERA5" will match any module whose name ends in "ERA5";
# encoders and decoders that exist per stream have the stream name attached at the end
freeze_modules: ""

# whether to track the exponential moving average of weights for validation
validate_with_ema: True
ema_ramp_up_ratio: 0.09
ema_halflife_in_thousands: 1e-3

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train with auto-encoder mode, forecast_offset should be 0
training_mode: "masking"
training_mode_config:
  losses:
    LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]}
# training_mode_config:
#   losses:
#     LossPhysical: [['mse', 0.7]]
#     LossLatent: [['mse', 0.3]]
#     LossStudentTeacher: [{'iBOT': {<options>}, 'JEPA': {<options>}}]
validation_mode_config:
  losses:
    LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]}
# masking rate when training mode is "masking"; ignored in "forecast" mode
masking_rate: 0.6
# sample the masking rate from a normal distribution centered at masking_rate;
# note that a sampled masking rate leads to varying memory requirements across batches
masking_rate_sampling: True
# sample a subset of all target points, useful e.g. to reduce memory requirements (also can specify per-stream)
sampling_rate_target: 1.0
# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination"
masking_strategy: "random"
# masking_strategy_config is a dictionary of additional parameters for the masking strategy,
# required for the "healpix" and "channel" masking strategies:
# "healpix": requires the healpix mask level to be specified with `hl_mask`
# "channel": requires "mode" to be specified, either "per_cell" or "global"
masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
"probabilities": [0.34, 0.33, 0.33],
"hl_mask": 3, "mode": "per_cell",
"same_strategy_per_batch": false
}

num_mini_epochs: 32
samples_per_mini_epoch: 4096
samples_per_validation: 512

shuffle: True

lr_scaling_policy: "sqrt"
lr_start: 1e-6
lr_max: 5e-5
lr_final_decay: 1e-6
lr_final: 0.0
lr_steps_warmup: 512
lr_steps_cooldown: 512
lr_policy_warmup: "cosine"
lr_policy_decay: "constant"
lr_policy_cooldown: "linear"

grad_clip: 1.0
weight_decay: 0.1
norm_type: "LayerNorm"
nn_module: "te"
log_grad_norms: False

start_date: 197901010000
end_date: 202012310000
start_date_val: 201705010000 #202101010000
end_date_val: 201706300000 #202201010000
len_hrs: 6
step_hrs: 6
input_window_steps: 1

val_initial: False

loader_num_workers: 8
log_validation: 0
streams_output: ["ERA5"]

istep: 0
run_history: []

desc: ""
data_loader_rng_seed: ???
run_id: ???

# Logging frequencies in the training loop (in numbers of batch steps)
train_log_freq:
terminal: 10
metrics: 20
checkpoint: 250


# Tags for experiment tracking
# These tags will be logged in MLFlow along with completed runs for train, eval, val
# The tags are free-form, with the following rules:
# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries
# - tags should not duplicate existing config entries.
# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags
# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future)
wgtags:
# The name of the organization of the person running the experiment.
# This may be autofilled in the future. Expected values are lowercase strings of
# the organizations' codenames in https://confluence.ecmwf.int/display/MAEL/Staff+Contact+List
# e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience"
org: mpg
# The name of the experiment. This is a distinctive codename for the experiment campaign being run.
# This is expected to be the primary tag for comparing experiments in MLFlow.
# Expected values are lowercase strings with no spaces, just underscores:
# Examples: "rollout_ablation_grid"
exp: lst_finetune
# *** Experiment-specific tags ***
grid: v0
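As an aside on freeze_modules: the comment in the config specifies full-match semantics. A minimal sketch of how such a pattern could be applied to a PyTorch model (freeze_matching is a hypothetical helper, not the repository's implementation):

```python
import re

import torch.nn as nn

def freeze_matching(model: nn.Module, pattern: str) -> None:
    """Freeze every submodule whose qualified name fully matches `pattern`."""
    if not pattern:
        return  # the default "" freezes nothing
    rx = re.compile(pattern)
    for name, module in model.named_modules():
        if rx.fullmatch(name):
            for p in module.parameters():
                p.requires_grad = False

# e.g. freeze_matching(model, ".*ERA5") freezes the per-stream encoders/decoders
# whose names end in "ERA5", matching the example in the comment above.
```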
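The lr_* keys describe a three-phase schedule: a cosine warmup from lr_start to lr_max, a decay phase (here "constant"), and a linear cooldown towards lr_final. A sketch of one way these settings could compose (hypothetical helper; lr_scaling_policy and lr_final_decay are ignored since the decay phase is constant here):

```python
import math

def lr_at(step, total_steps, lr_start=1e-6, lr_max=5e-5, lr_final=0.0,
          steps_warmup=512, steps_cooldown=512):
    """Hypothetical three-phase schedule mirroring the lr_* keys above."""
    if step < steps_warmup:
        # cosine warmup: lr_start -> lr_max
        t = step / steps_warmup
        return lr_start + (lr_max - lr_start) * 0.5 * (1.0 - math.cos(math.pi * t))
    if step < total_steps - steps_cooldown:
        # "constant" decay policy: hold lr_max
        return lr_max
    # linear cooldown: lr_max -> lr_final
    t = (step - (total_steps - steps_cooldown)) / steps_cooldown
    return lr_max + (lr_final - lr_max) * t
```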
30 changes: 30 additions & 0 deletions config/streams/seviri_lst/era5.yml
@@ -0,0 +1,30 @@
ERA5 :
type : anemoi
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
stream_id : 0
source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp']
target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp']
loss_weight : 1.
location_weight : cosine_latitude
masking_rate : 0.6
masking_rate_none : 0.05
token_size : 8
tokenize_spacetime : True
max_num_targets: -1
forcing: True
embed :
net : transformer
num_tokens : 1
num_heads : 8
dim_embed : 512
num_blocks : 2
embed_target_coords :
net : linear
dim_embed : 512
target_readout :
num_layers : 2
num_heads : 4
# sampling_rate : 0.2
pred_head :
ens_size : 1
num_layers : 1
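location_weight : cosine_latitude weights each grid point by the cosine of its latitude, which is proportional to the area a point represents on a regular latitude-longitude grid. A minimal sketch (NumPy; the mean-normalization is an assumed convention, not taken from the repo):

```python
import numpy as np

def cosine_latitude_weights(lat_deg: np.ndarray) -> np.ndarray:
    """Loss weights proportional to cos(latitude), normalized to mean 1."""
    w = np.cos(np.deg2rad(lat_deg))
    return w / w.mean()

# Before normalization, equatorial points get weight cos(0) = 1 while points
# at 60 degrees latitude get cos(60 deg) = 0.5.
```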
36 changes: 36 additions & 0 deletions config/streams/seviri_lst/seviri_lst.yml
@@ -0,0 +1,36 @@
SEVIRI_LST :
Collaborator review comment on SEVIRI_LST: Please remove from the PR. We need a separate repo for the configs and are in the process of creating it.

type : msg_lst
stream_id: 1
filenames : ['mpg_seviri_l2_2017-18_v0/lst_test.zarr'] # use ['mpg_seviri_l2_2017-18_v0/seviri.zarr'] after zarr3 is enabled
data_start_time : "2017-02-01 00:00"
data_end_time : "2017-06-30 00:00"
target: ["LST"]
source: []
geoinfos: [] # e.g. ["DEM", "LANDCOV"]
metadata: "/leonardo_work/AIFAC_5C0_154/weathergen/data/mpg_seviri_l2_2017-18_v1/metadata" # uses one scene over South Africa for fine-tuning
scene: "scenes_train_scene_001.npz"
spatial_stride: 24
temporal_stride: 6
sampling_rate_target: 0.005 # keep 0.5% of spatial target points
loss_weight : 1.0
masking_rate : 0.6
masking_rate_none : 0.05
token_size : 64
tokenize_spacetime : True
max_num_targets: -1
embed :
net : transformer
num_tokens : 1
num_heads : 2
dim_embed : 16
num_blocks : 2
embed_target_coords :
net : linear
dim_embed : 16
target_readout :
type : 'obs_value'
num_layers : 2
num_heads : 4
pred_head :
ens_size : 1
num_layers : 1
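A note on sampling_rate_target: 0.005 above: it keeps roughly 0.5% of the spatial target points. A sketch of the Bernoulli-style subsampling this implies (illustrative, not the repository's code):

```python
import numpy as np

rng = np.random.default_rng(0)

def subsample_targets(n_points: int, rate: float = 0.005) -> np.ndarray:
    """Boolean mask keeping each target point independently with probability `rate`."""
    return rng.random(n_points) < rate

# e.g. for 100_000 candidate LST pixels, rate=0.005 retains ~500 targets on average.
```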
@@ -607,7 +607,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]:
image_paths += names

if image_paths:
image_paths=sorted(image_paths)
image_paths = sorted(image_paths)
images = [Image.open(path) for path in image_paths]
images[0].save(
f"{map_output_dir}/animation_{self.run_id}_{tag}_{sa}_{self.stream}_{region}_{var}.gif",
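For context, the hunk above is part of the standard Pillow idiom for assembling sorted frames into a GIF. A self-contained sketch (the glob pattern, duration, and loop values are assumptions, not taken from the repository):

```python
import glob

from PIL import Image

# Frames must be sorted so the animation plays in order; the whitespace fix
# in this hunk touches exactly that line.
image_paths = sorted(glob.glob("plots/maps/*.png"))
if image_paths:
    images = [Image.open(path) for path in image_paths]
    images[0].save(
        "animation.gif",
        save_all=True,             # write a multi-frame GIF
        append_images=images[1:],  # remaining frames
        duration=500,              # ms per frame (assumed value)
        loop=0,                    # loop forever (assumed value)
    )
```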
1 change: 1 addition & 0 deletions packages/evaluate/src/weathergen/evaluate/utils/regions.py
@@ -28,6 +28,7 @@ class RegionLibrary:
"shem": (-90.0, 0.0, -180.0, 180.0),
"tropics": (-30.0, 30.0, -180.0, 180.0),
"belgium": (49, 52, 2, 7),
"madagaskar": (-25, -10, 43, 50),
}
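The region tuples follow the (lat_min, lat_max, lon_min, lon_max) convention, as the existing entries show. A sketch of how the new box could be used to select points for regional scores (region_mask is a hypothetical helper; NumPy assumed):

```python
import numpy as np

MADAGASKAR = (-25.0, -10.0, 43.0, 50.0)  # (lat_min, lat_max, lon_min, lon_max)

def region_mask(lat: np.ndarray, lon: np.ndarray, box) -> np.ndarray:
    """Boolean mask for points falling inside a lat/lon bounding box."""
    lat_min, lat_max, lon_min, lon_max = box
    return (lat >= lat_min) & (lat <= lat_max) & (lon >= lon_min) & (lon <= lon_max)

# Regional RMSE could then be computed over values[region_mask(lat, lon, MADAGASKAR)].
```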