From 317501e282fba573c43ac64ba3f32dac42aa5511 Mon Sep 17 00:00:00 2001 From: wang85 Date: Wed, 16 Jul 2025 10:07:48 +0200 Subject: [PATCH 01/36] Replace cf.rank==0 with utils.distributed.is_root --- src/weathergen/train/trainer.py | 16 +++++----- src/weathergen/utils/plot_training.py | 44 ++++++++++++++++++--------- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index d5d63d62f..eb5e8deca 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -68,7 +68,7 @@ def init( cf.model_path = cf.model_path if hasattr(cf, "model_path") else "./models" path_run = Path(cf.run_path) / cf.run_id path_model = Path(cf.model_path) / cf.run_id - if self.cf.rank == 0: + if is_root(): path_run.mkdir(exist_ok=True, parents=True) path_model.mkdir(exist_ok=True, parents=True) self.path_run = path_run @@ -121,7 +121,7 @@ def inference(self, cf, run_id_trained, epoch): for name, w in cf.loss_fcts_val: self.loss_fcts_val += [[getattr(losses, name), w]] - if self.cf.rank == 0: + if is_root(): config.save(self.cf, epoch=0) _logger.info(f"Starting inference with id={self.cf.run_id}.") @@ -216,7 +216,7 @@ def run(self, cf, run_id_contd=None, epoch_contd=None): self.model_params = ModelParams().create(cf).to("cuda") # if with_fsdp then parameter count is unreliable - if (self.cf.rank == 0 and not cf.with_fsdp) or not cf.with_ddp: + if (is_root() and not cf.with_fsdp) or not cf.with_ddp: self.model.print_num_parameters() # TODO: learning rate schedule @@ -277,7 +277,7 @@ def run(self, cf, run_id_contd=None, epoch_contd=None): cf.lr_scaling_policy, ) - if self.cf.istep > 0 and self.cf.rank == 0: + if self.cf.istep > 0 and is_root(): str = f"Continuing run with learning rate: {self.lr_scheduler.get_lr()}" _logger.info(str) @@ -659,7 +659,7 @@ def validate(self, epoch): torch.stack(self.stddev_hist).to(torch.float64).nanmean(0) ) - if self.cf.rank == 0 and self.cf.istep >= 0: + if is_root() and self.cf.istep >= 0: loss_dict = {} for j, (lname, _) in enumerate(cf.loss_fcts_val): loss_dict[f"validation {lname}"] = torch.nanmean(losses_all[j]).item() @@ -673,7 +673,7 @@ def validate(self, epoch): samples = cf.istep * cf.batch_size * cf.num_ranks self.train_logger.add_val(samples, losses_all, stddev_all) - if self.cf.rank == 0: + if is_root(): print( f"validation ({cf.run_id}) : {epoch:03d} :", f" loss = {torch.nanmean(losses_all[0]):.4E}", @@ -738,7 +738,7 @@ def log(self, bidx): stddev_avg = self.ddp_average(torch.nanmean(torch.stack(self.stddev_hist), axis=0)) samples = self.cf.istep * self.cf.batch_size * self.cf.num_ranks - if self.cf.rank == 0: + if is_root(): # logging loss_dict = { "training mse": float(torch.nanmean(l_avg[0])), @@ -770,7 +770,7 @@ def log_terminal(self, bidx, epoch): nanmean(torch.stack(self.losses_hist[-self.print_freq :]), axis=0) ) - if self.cf.rank == 0: + if is_root(): # samples per sec dt = time.time() - self.t_start pstr = "{:03d} : {:05d}/{:05d} : {:06d} : loss = {:.4E} " diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 3fb6cb502..44ac6383e 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -10,6 +10,7 @@ import argparse import logging import subprocess +import sys from pathlib import Path import matplotlib.pyplot as plt @@ -21,6 +22,8 @@ _logger = logging.getLogger(__name__) +DEFAULT_RUN_FILE = Path("./config/runs_plot_train.yml") + 
#################################################################################################### def _ensure_list(value): @@ -556,7 +559,7 @@ def plot_loss_per_run( plt.close() -def plot_train(): +def plot_train(args=None): # Example usage: # When providing a YAML for configuring the run IDs: # python plot_training.py -rf eval_run.yml -m ./trained_models -o ./training_plots @@ -622,27 +625,24 @@ def plot_train(): run_id_group = parser.add_mutually_exclusive_group() run_id_group.add_argument( - "-rs", - "--run_ids_dict", + "-fd", + "--from_dict", type=_read_str_config, - dest="rs", - help=( - "Dictionary-string of form '{run_id: [job_id, experiment_name]}'", - " for training runs to plot", - ), + dest="fd", + help="Dictionary-string of form '{run_id: [job_id, experiment_name]}'" + + "for training runs to plot", ) run_id_group.add_argument( - "-rf", - "--run_ids_file", - dest="rf", - default="./config/runs_plot_train.yml", + "-fy", + "--from_yaml", + dest="fy", type=_read_yaml_config, help="YAML file configuring the training run ids to plot", ) # parse the command line arguments - args = parser.parse_args() + args = parser.parse_args(args) model_base_dir = Path(args.model_base_dir) out_dir = Path(args.output_dir) @@ -651,7 +651,17 @@ def plot_train(): if args.x_type not in x_types_valid: raise ValueError(f"x_type must be one of {x_types_valid}, but got {args.x_type}") - runs_ids = args.rs if args.rs is not None else args.rf + # Post-processing default logic for config from YAML-file + if args.fd is None and args.fy is None: + if DEFAULT_RUN_FILE.exists(): + args.fy = _read_yaml_config(DEFAULT_RUN_FILE) + else: + raise ValueError( + f"Please provide a run_id dictionary or a YAML file with run_ids, " + f"or create a default file at {DEFAULT_RUN_FILE}." 
+ ) + + runs_ids = args.fd if args.fd is not None else args.fy if args.delete == "True": clean_plot_folder(out_dir) @@ -725,3 +735,9 @@ def plot_train(): get_stream_names(run_id, model_path=model_base_dir), # limit to available streams plot_dir=out_dir, ) + + +if __name__ == "__main__": + args = sys.argv[1:] # get CLI args + + plot_train(args) \ No newline at end of file From 734a96bfe3ae3a51413f3513c37c6c258b55d15c Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Tue, 9 Dec 2025 12:02:22 +0100 Subject: [PATCH 02/36] add 2d rope to develop --- config/default_config.yml | 5 ++ src/weathergen/model/attention.py | 11 ++- src/weathergen/model/engines.py | 20 +++-- src/weathergen/model/model.py | 99 +++++++++++++++------ src/weathergen/model/positional_encoding.py | 76 ++++++++++++++++ 5 files changed, 179 insertions(+), 32 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 66b57c865..c3d0cfc35 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -63,6 +63,11 @@ impute_latent_noise_std: 0.0 # 1e-4 healpix_level: 5 +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +use_2D_rope: True + with_mixed_precision: True with_flash_attention: True compile_model: False diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 39ed1c041..227e8c47a 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,6 +14,7 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.norms import AdaLayerNorm, RMSNorm +from weathergen.model.positional_encoding import apply_rotary_pos_emb_2d class MultiSelfAttentionHeadVarlen(torch.nn.Module): @@ -242,7 +243,7 @@ def mask_block_local(batch, head, idx_q, idx_kv): # compile for efficiency self.flex_attention = torch.compile(flex_attention, dynamic=False) - def forward(self, x, ada_ln_aux=None): + def forward(self, x, coords=None, ada_ln_aux=None): if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) @@ -253,6 +254,9 @@ def forward(self, x, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3]) + if coords is not None: + qs, ks = apply_rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1) + outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) out = self.proj_out(self.dropout(outs.flatten(-2, -1))) @@ -527,7 +531,7 @@ def __init__( self.att = self.attention self.softmax = torch.nn.Softmax(dim=-1) - def forward(self, x, ada_ln_aux=None): + def forward(self, x, coords=None, ada_ln_aux=None): if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) @@ -539,6 +543,9 @@ def forward(self, x, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s).to(self.dtype) + if coords is not None: + qs, ks = apply_rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=2) + # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 190fd6548..b10632ff2 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ 
-303,9 +303,12 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) ) - def forward(self, tokens, use_reentrant): + def forward(self, tokens, coords=None, use_reentrant=True): for block in self.ae_aggregation_blocks: - tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) + if isinstance(block, (MultiSelfAttentionHead, MultiSelfAttentionHeadLocal)): + tokens = checkpoint(block, tokens, coords, use_reentrant=use_reentrant) + else: + tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) return tokens @@ -371,9 +374,12 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) ) - def forward(self, tokens, use_reentrant): + def forward(self, tokens, coords=None, use_reentrant=True): for block in self.ae_global_blocks: - tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) + if isinstance(block, (MultiSelfAttentionHead, MultiSelfAttentionHeadLocal)): + tokens = checkpoint(block, tokens, coords, use_reentrant=use_reentrant) + else: + tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) return tokens @@ -451,7 +457,11 @@ def init_weights_final(m): def forward(self, tokens, fstep): aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") for block in self.fe_blocks: - tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) + if isinstance(block, (MultiSelfAttentionHead, MultiSelfAttentionHeadLocal)): + # No RoPE coords during forecasting; pass aux to AdaLayerNorm. + tokens = checkpoint(block, tokens, None, aux_info, use_reentrant=False) + else: + tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) return tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index b773a4b96..5c5a8b68d 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -23,6 +23,7 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.engines import ( EmbeddingEngine, EnsPredictionHead, @@ -83,6 +84,23 @@ def __init__(self, cf) -> None: ) self.pe_global = torch.nn.Parameter(pe, requires_grad=False) + ### ROPE COORDS (for 2D RoPE when use_2D_rope=True) ### + # Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE. + # Shape: (num_healpix_cells, ae_local_num_queries, 2) + self.use_2D_rope = getattr(cf, "use_2D_rope", False) + if self.use_2D_rope: + self.rope_coords = torch.nn.Parameter( + torch.zeros( + self.num_healpix_cells, + cf.ae_local_num_queries, + 2, + dtype=torch.float32, + ), + requires_grad=False, + ) + else: + self.rope_coords = None + ### HEALPIX NEIGHBOURS ### hlc = self.healpix_level with warnings.catch_warnings(action="ignore"): @@ -150,28 +168,39 @@ def reset_parameters(self, cf: Config) -> "ModelParams": self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]]) dim_embed = cf.ae_global_dim_embed - self.pe_global.data.fill_(0.0) - xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed - self.pe_global.data[..., 0::2] = 0.5 * torch.sin( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 0::2] += ( - torch.sin( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + + if self.use_2D_rope: + # Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE. 
+ verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) + coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) + coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) + self.rope_coords.data.copy_(coords) + # Clear pe_global when using 2D RoPE + self.pe_global.data.fill_(0.0) + else: + # Original pe_global initialization + self.pe_global.data.fill_(0.0) + xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed + self.pe_global.data[..., 0::2] = 0.5 * torch.sin( + torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) - self.pe_global.data[..., 1::2] = 0.5 * torch.cos( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 1::2] += ( - torch.cos( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + self.pe_global.data[..., 0::2] += ( + torch.sin( + torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + ) + .unsqueeze(1) + .repeat((1, cf.ae_local_num_queries, 1)) + ) + self.pe_global.data[..., 1::2] = 0.5 * torch.cos( + torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) + ) + self.pe_global.data[..., 1::2] += ( + torch.cos( + torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + ) + .unsqueeze(1) + .repeat((1, cf.ae_local_num_queries, 1)) ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) # healpix neighborhood structure @@ -268,6 +297,8 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.sources_size = sources_size self.targets_num_channels = targets_num_channels self.targets_coords_size = targets_coords_size + + self.use_2D_rope = getattr(cf, "use_2D_rope", False) self.ae_aggregation_engine: QueryAggregationEngine | None = None self.ae_global_engine: GlobalAssimilationEngine | None = None @@ -781,15 +812,23 @@ def assimilate_local( # (applying this here assumes batch_size=1) # permute to use ae_local_num_queries as the batchsize and no_of_tokens # as seq len for flash attention + + # create mask from cell lens + mask = cell_lens.to(torch.bool) + + coords_global = None + if self.use_2D_rope: + coords_global = ( + model_params.rope_coords[mask] + .to(device=tokens_global_unmasked.device, dtype=tokens_global_unmasked.dtype) + .permute(1, 0, 2) + ) tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) tokens_global_unmasked = self.ae_aggregation_engine( - tokens_global_unmasked, use_reentrant=False + tokens_global_unmasked, coords=coords_global, use_reentrant=False ) tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) - # create mask from cell lens - mask = cell_lens.to(torch.bool) - # fill empty tensor using mask for positions of unmasked tokens tokens_global[mask] = tokens_global_unmasked.to(tokens_global.dtype) @@ -811,8 +850,18 @@ def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> Latent representation of the model """ + batch_size = tokens.shape[0] + coords = None + if self.use_2D_rope: + coords = ( + model_params.rope_coords.flatten(0, 1) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + .to(device=tokens.device, dtype=tokens.dtype) + ) + # global assimilation engine and adapter - tokens = self.ae_global_engine(tokens, use_reentrant=False) + tokens = self.ae_global_engine(tokens, coords=coords, 
use_reentrant=False) return tokens diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index 88df67fa3..9778e3046 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -94,3 +94,79 @@ def positional_encoding_harmonic_coord(x, lats, lons): x = x + pe return x + + +#################################################################################################### +# Rotary positional embeddings (2D) adapted from Qwen3 & LLama for reuse in WeatherGenerator. +# https://github.com/qiuzh20/gated_attention/blob/main/modeling_qwen3.py +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q: Query tensor. + k: Key tensor. + cos: Cosine embedding tensor. + sin: Sine embedding tensor. + position_ids: Deprecated and unused; present for API compatibility. + unsqueeze_dim: Dimension along which to unsqueeze cos/sin for broadcasting. + """ + + cos = cos.unsqueeze(unsqueeze_dim).to(dtype=q.dtype) + sin = sin.unsqueeze(unsqueeze_dim).to(dtype=q.dtype) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed, k_embed + + +#################################################################################################### +def rotary_embedding_2d(coords, dim, base=10000.0): + """Create 2D RoPE embeddings from latitude/longitude coordinates. + + Args: + coords: Tensor of shape (..., 2) with coordinates in radians (lat, lon). + dim: Head dimension to encode; must be divisible by 4. + base: RoPE base frequency. + + Returns: + Tuple of (cos, sin) tensors with shape (..., dim). + """ + + if coords.shape[-1] != 2: + raise ValueError(f"coords last dimension must be 2 (lat, lon); got {coords.shape[-1]}") + if dim % 4 != 0: + raise ValueError(f"2D rotary embeddings require dim to be divisible by 4; got {dim}") + + # Split the rotary frequencies evenly between latitude and longitude to stay local to each cell. 
+ half_dim = dim // 2 + inv_freq = 1.0 / ( + base ** (torch.arange(0, half_dim, 2, device=coords.device, dtype=coords.dtype) / half_dim) + ) + + lat, lon = coords.unbind(dim=-1) + freq_lat = lat.unsqueeze(-1) * inv_freq + freq_lon = lon.unsqueeze(-1) * inv_freq + + freqs = torch.cat((freq_lat, freq_lon), dim=-1) + emb = torch.cat((freqs, freqs), dim=-1) + + cos = torch.cos(emb) + sin = torch.sin(emb) + + return cos, sin + + +#################################################################################################### +def apply_rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): + """Convenience wrapper that builds 2D RoPE embeddings and applies them to q/k.""" + + cos, sin = rotary_embedding_2d(coords, q.shape[-1], base=base) + return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) From e4be7c3135734c5785df953de7c0a86220d4bbaf Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Tue, 9 Dec 2025 13:04:30 +0100 Subject: [PATCH 03/36] simplify assimilate global, forecast mode config --- config/default_config.yml | 12 ++++++------ src/weathergen/model/model.py | 3 +-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index c3d0cfc35..b9b4abfda 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -50,12 +50,12 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 0 +forecast_offset : 1 forecast_delta_hrs: 0 -forecast_steps: 0 -forecast_policy: null +forecast_steps: 2 +forecast_policy: fixed forecast_att_dense_rate: 1.0 -fe_num_blocks: 0 +fe_num_blocks: 8 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -98,7 +98,7 @@ ema_halflife_in_thousands: 1e-3 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" +training_mode: "forecast" training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]},} } # training_mode_config: {"loss": {LossPhysical: [['mse', 0.7]], @@ -126,7 +126,7 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false } -num_mini_epochs: 32 +num_mini_epochs: 16 samples_per_mini_epoch: 4096 samples_per_validation: 512 diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 5c5a8b68d..b6652b727 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -850,13 +850,12 @@ def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> Latent representation of the model """ - batch_size = tokens.shape[0] coords = None if self.use_2D_rope: coords = ( model_params.rope_coords.flatten(0, 1) .unsqueeze(0) - .repeat(batch_size, 1, 1) + .repeat(tokens.shape[0], 1, 1) .to(device=tokens.device, dtype=tokens.dtype) ) From b19365566a4a159817f41e528d4e41d906e42fe0 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Tue, 9 Dec 2025 13:16:40 +0100 Subject: [PATCH 04/36] add 2d rope to forecast engine only once --- src/weathergen/model/engines.py | 7 ++++--- src/weathergen/model/model.py | 23 +++++++++++++++++++---- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index b10632ff2..b7a304ae2 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -454,12 +454,13 @@ def
init_weights_final(m): for block in self.fe_blocks: block.apply(init_weights_final) - def forward(self, tokens, fstep): + def forward(self, tokens, fstep, coords=None): aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") for block in self.fe_blocks: if isinstance(block, (MultiSelfAttentionHead, MultiSelfAttentionHeadLocal)): - # No RoPE coords during forecasting; pass aux to AdaLayerNorm. - tokens = checkpoint(block, tokens, None, aux_info, use_reentrant=False) + # Pass coords if provided (for 2D RoPE on first forecast step), otherwise None + # Always pass aux to AdaLayerNorm. + tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) else: tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index b6652b727..94dc4e494 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -667,7 +667,10 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca if noise_std > 0.0: tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std - tokens = self.forecast(model_params, tokens, fstep) + # Apply 2D RoPE coords only on the first forecast step + # to help model understand spatial relationships before temporal roll-out + is_first_forecast = fstep == forecast_offset + tokens = self.forecast(model_params, tokens, fstep, apply_rope=is_first_forecast) # prediction for final step preds_all += [ @@ -865,20 +868,32 @@ def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> return tokens ######################################### - def forecast(self, model_params: ModelParams, tokens: torch.Tensor, fstep: int) -> torch.Tensor: + def forecast( + self, model_params: ModelParams, tokens: torch.Tensor, fstep: int, apply_rope: bool = False + ) -> torch.Tensor: """Advances latent space representation in time Args: - model_params : Query and embedding parameters (never used) + model_params : Query and embedding parameters tokens : Input tokens to be processed by the model. fstep: Current forecast step index (can be used as aux info). + apply_rope: Whether to apply 2D RoPE coords (only on first forecast step). 
Returns: Processed tokens Raises: ValueError: For unexpected arguments in checkpoint method """ - tokens = self.forecast_engine(tokens, fstep) + coords = None + if apply_rope and self.use_2D_rope: + coords = ( + model_params.rope_coords.flatten(0, 1) + .unsqueeze(0) + .repeat(tokens.shape[0], 1, 1) + .to(device=tokens.device, dtype=tokens.dtype) + ) + + tokens = self.forecast_engine(tokens, fstep, coords=coords) return tokens From 77d95ede8216cbbdda71f194bbb228a08133b255 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Wed, 10 Dec 2025 14:04:38 +0100 Subject: [PATCH 05/36] only keep global & forecast engine add 2d rope --- src/weathergen/model/model.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 94dc4e494..4b2445be4 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -815,23 +815,15 @@ def assimilate_local( # (applying this here assumes batch_size=1) # permute to use ae_local_num_queries as the batchsize and no_of_tokens # as seq len for flash attention - - # create mask from cell lens - mask = cell_lens.to(torch.bool) - - coords_global = None - if self.use_2D_rope: - coords_global = ( - model_params.rope_coords[mask] - .to(device=tokens_global_unmasked.device, dtype=tokens_global_unmasked.dtype) - .permute(1, 0, 2) - ) tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) tokens_global_unmasked = self.ae_aggregation_engine( - tokens_global_unmasked, coords=coords_global, use_reentrant=False + tokens_global_unmasked, coords=None, use_reentrant=False ) tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) + # create mask from cell lens + mask = cell_lens.to(torch.bool) + # fill empty tensor using mask for positions of unmasked tokens tokens_global[mask] = tokens_global_unmasked.to(tokens_global.dtype) From 3928e593e4b65122be2d1298ec91b6ad1a9d44f7 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Wed, 10 Dec 2025 16:10:11 +0100 Subject: [PATCH 06/36] simplify the code --- src/weathergen/model/model.py | 45 ++++++++++++----------------------- 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 4b2445be4..a61bd1b0b 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -86,15 +86,15 @@ def __init__(self, cf) -> None: ### ROPE COORDS (for 2D RoPE when use_2D_rope=True) ### # Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE. - # Shape: (num_healpix_cells, ae_local_num_queries, 2) - self.use_2D_rope = getattr(cf, "use_2D_rope", False) + # Shape: (bs, num_healpix_cells * ae_local_num_queries, 2) + self.use_2D_rope = cf.use_2D_rope if self.use_2D_rope: self.rope_coords = torch.nn.Parameter( torch.zeros( - self.num_healpix_cells, - cf.ae_local_num_queries, + bs, + self.num_healpix_cells * cf.ae_local_num_queries, 2, - dtype=torch.float32, + dtype=self.dtype, ), requires_grad=False, ) @@ -156,6 +156,7 @@ def reset_parameters(self, cf: Config) -> "ModelParams": # positional encodings + bs = cf.batch_size_per_gpu dim_embed = cf.ae_local_dim_embed len_token_seq = 1024 self.pe_embed.data.fill_(0.0) @@ -171,10 +172,13 @@ def reset_parameters(self, cf: Config) -> "ModelParams": if self.use_2D_rope: # Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE. 
+ # Shape: (num_healpix_cells, ae_local_num_queries, 2) verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) - self.rope_coords.data.copy_(coords) + # Transform to final shape: (bs, num_healpix_cells * ae_local_num_queries, 2) + coords_flat = coords.flatten(0, 1).unsqueeze(0).repeat(bs, 1, 1) + self.rope_coords.data.copy_(coords_flat) # Clear pe_global when using 2D RoPE self.pe_global.data.fill_(0.0) else: @@ -217,7 +221,6 @@ def reset_parameters(self, cf: Config) -> "ModelParams": # varlen index set for tokens assert cf.batch_size_per_gpu == cf.batch_size_validation_per_gpu - bs = cf.batch_size_per_gpu nqs = 9 s = [bs, self.num_healpix_cells, cf.ae_local_num_queries, cf.ae_global_dim_embed] if cf.target_cell_local_prediction: @@ -297,8 +300,6 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.sources_size = sources_size self.targets_num_channels = targets_num_channels self.targets_coords_size = targets_coords_size - - self.use_2D_rope = getattr(cf, "use_2D_rope", False) self.ae_aggregation_engine: QueryAggregationEngine | None = None self.ae_global_engine: GlobalAssimilationEngine | None = None @@ -668,7 +669,6 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std # Apply 2D RoPE coords only on the first forecast step - # to help model understand spatial relationships before temporal roll-out is_first_forecast = fstep == forecast_offset tokens = self.forecast(model_params, tokens, fstep, apply_rope=is_first_forecast) @@ -845,17 +845,8 @@ def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> Latent representation of the model """ - coords = None - if self.use_2D_rope: - coords = ( - model_params.rope_coords.flatten(0, 1) - .unsqueeze(0) - .repeat(tokens.shape[0], 1, 1) - .to(device=tokens.device, dtype=tokens.dtype) - ) - # global assimilation engine and adapter - tokens = self.ae_global_engine(tokens, coords=coords, use_reentrant=False) + tokens = self.ae_global_engine(tokens, coords=model_params.rope_coords, use_reentrant=False) return tokens @@ -876,16 +867,10 @@ def forecast( ValueError: For unexpected arguments in checkpoint method """ - coords = None - if apply_rope and self.use_2D_rope: - coords = ( - model_params.rope_coords.flatten(0, 1) - .unsqueeze(0) - .repeat(tokens.shape[0], 1, 1) - .to(device=tokens.device, dtype=tokens.dtype) - ) - - tokens = self.forecast_engine(tokens, fstep, coords=coords) + if apply_rope and model_params.use_2D_rope: + tokens = self.forecast_engine(tokens, fstep, coords=model_params.rope_coords) + else: + tokens = self.forecast_engine(tokens, fstep, coords=None) return tokens From b54206791043f620849b3d75625387c4857835dd Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Wed, 10 Dec 2025 16:20:30 +0100 Subject: [PATCH 07/36] fix lint --- src/weathergen/model/engines.py | 6 +++--- src/weathergen/model/model.py | 25 +++++++++++++++++++------ 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index b7a304ae2..9d82c7eb2 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -305,7 +305,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: def forward(self, tokens, coords=None, use_reentrant=True): for block in 
self.ae_aggregation_blocks: - if isinstance(block, (MultiSelfAttentionHead, MultiSelfAttentionHeadLocal)): + if isinstance(block, MultiSelfAttentionHead | MultiSelfAttentionHeadLocal): tokens = checkpoint(block, tokens, coords, use_reentrant=use_reentrant) else: tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) @@ -376,7 +376,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: def forward(self, tokens, coords=None, use_reentrant=True): for block in self.ae_global_blocks: - if isinstance(block, (MultiSelfAttentionHead, MultiSelfAttentionHeadLocal)): + if isinstance(block, MultiSelfAttentionHead | MultiSelfAttentionHeadLocal): tokens = checkpoint(block, tokens, coords, use_reentrant=use_reentrant) else: tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) @@ -457,7 +457,7 @@ def init_weights_final(m): def forward(self, tokens, fstep, coords=None): aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") for block in self.fe_blocks: - if isinstance(block, (MultiSelfAttentionHead, MultiSelfAttentionHeadLocal)): + if isinstance(block, MultiSelfAttentionHead | MultiSelfAttentionHeadLocal): # Pass coords if provided (for 2D RoPE on first forecast step), otherwise None # Always pass aux to AdaLayerNorm. tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index a61bd1b0b..e68b64e31 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -169,7 +169,7 @@ def reset_parameters(self, cf: Config) -> "ModelParams": self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]]) dim_embed = cf.ae_global_dim_embed - + if self.use_2D_rope: # Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE. 
# Shape: (num_healpix_cells, ae_local_num_queries, 2) @@ -184,23 +184,36 @@ def reset_parameters(self, cf: Config) -> "ModelParams": else: # Original pe_global initialization self.pe_global.data.fill_(0.0) - xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed + xs = ( + 2.0 + * np.pi + * torch.arange(0, dim_embed, 2, device=self.pe_global.device) + / dim_embed + ) self.pe_global.data[..., 0::2] = 0.5 * torch.sin( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) + torch.outer( + 8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs + ) ) self.pe_global.data[..., 0::2] += ( torch.sin( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + torch.outer( + torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs + ) ) .unsqueeze(1) .repeat((1, cf.ae_local_num_queries, 1)) ) self.pe_global.data[..., 1::2] = 0.5 * torch.cos( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) + torch.outer( + 8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs + ) ) self.pe_global.data[..., 1::2] += ( torch.cos( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + torch.outer( + torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs + ) ) .unsqueeze(1) .repeat((1, cf.ae_local_num_queries, 1)) From 8538291507901367d9ad9fc78b9cbe54f405ead4 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Wed, 10 Dec 2025 16:24:56 +0100 Subject: [PATCH 08/36] small fix --- src/weathergen/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index e68b64e31..74ba54036 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -870,7 +870,7 @@ def forecast( """Advances latent space representation in time Args: - model_params : Query and embedding parameters + model_params : Query and embedding parameters (never used) tokens : Input tokens to be processed by the model. fstep: Current forecast step index (can be used as aux info). apply_rope: Whether to apply 2D RoPE coords (only on first forecast step). From 81a52b5b6866a6ac7e6d963c0314afa70b908b22 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Wed, 10 Dec 2025 16:30:03 +0100 Subject: [PATCH 09/36] fix annotation --- src/weathergen/model/engines.py | 2 -- src/weathergen/model/model.py | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 9d82c7eb2..f305c6bd6 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -458,8 +458,6 @@ def forward(self, tokens, fstep, coords=None): aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") for block in self.fe_blocks: if isinstance(block, MultiSelfAttentionHead | MultiSelfAttentionHeadLocal): - # Pass coords if provided (for 2D RoPE on first forecast step), otherwise None - # Always pass aux to AdaLayerNorm. 
tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) else: tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 74ba54036..d44863a51 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -85,8 +85,6 @@ def __init__(self, cf) -> None: self.pe_global = torch.nn.Parameter(pe, requires_grad=False) ### ROPE COORDS (for 2D RoPE when use_2D_rope=True) ### - # Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE. - # Shape: (bs, num_healpix_cells * ae_local_num_queries, 2) self.use_2D_rope = cf.use_2D_rope if self.use_2D_rope: self.rope_coords = torch.nn.Parameter( @@ -176,9 +174,9 @@ def reset_parameters(self, cf: Config) -> "ModelParams": verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) - # Transform to final shape: (bs, num_healpix_cells * ae_local_num_queries, 2) coords_flat = coords.flatten(0, 1).unsqueeze(0).repeat(bs, 1, 1) self.rope_coords.data.copy_(coords_flat) + # Clear pe_global when using 2D RoPE self.pe_global.data.fill_(0.0) else: From a8afd35d034a6597d357aa47878cd1a85a99517d Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Wed, 10 Dec 2025 16:32:30 +0100 Subject: [PATCH 10/36] fix lint --- src/weathergen/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index d44863a51..b6f639dd5 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -176,7 +176,7 @@ def reset_parameters(self, cf: Config) -> "ModelParams": coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) coords_flat = coords.flatten(0, 1).unsqueeze(0).repeat(bs, 1, 1) self.rope_coords.data.copy_(coords_flat) - + # Clear pe_global when using 2D RoPE self.pe_global.data.fill_(0.0) else: From 5a4889810f6674840c1545a30f97115a4997eebe Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Wed, 10 Dec 2025 16:35:53 +0100 Subject: [PATCH 11/36] add annotation --- src/weathergen/model/positional_encoding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index 9778e3046..f73faee1c 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -98,6 +98,7 @@ def positional_encoding_harmonic_coord(x, lats, lons): #################################################################################################### # Rotary positional embeddings (2D) adapted from Qwen3 & LLama for reuse in WeatherGenerator. 
+# rotate_half () and apply_rotary_pos_emb () from: # https://github.com/qiuzh20/gated_attention/blob/main/modeling_qwen3.py def rotate_half(x): """Rotates half the hidden dims of the input.""" From dc914ea02684b263ebc0161992f01169faa4eba9 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Wed, 10 Dec 2025 16:56:56 +0100 Subject: [PATCH 12/36] default config --- config/default_config.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index b9b4abfda..c3d0cfc35 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -50,12 +50,12 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 1 +forecast_offset : 0 forecast_delta_hrs: 0 -forecast_steps: 2 -forecast_policy: fixed +forecast_steps: 0 +forecast_policy: null forecast_att_dense_rate: 1.0 -fe_num_blocks: 8 +fe_num_blocks: 0 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -98,7 +98,7 @@ ema_halflife_in_thousands: 1e-3 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "forecast" +training_mode: "masking" training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]},} } # training_mode_config: {"loss": {LossPhysical: [['mse', 0.7]], @@ -126,7 +126,7 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false } -num_mini_epochs: 16 +num_mini_epochs: 32 samples_per_mini_epoch: 4096 samples_per_validation: 512 From d9e05047ba50ded1d8a98b449fc86b40b58302c8 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Wed, 10 Dec 2025 19:33:07 +0100 Subject: [PATCH 13/36] fix default use_reentrant --- src/weathergen/model/engines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index f305c6bd6..0c7696162 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -303,7 +303,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) ) - def forward(self, tokens, coords=None, use_reentrant=True): + def forward(self, tokens, coords=None, use_reentrant=False): for block in self.ae_aggregation_blocks: if isinstance(block, MultiSelfAttentionHead | MultiSelfAttentionHeadLocal): tokens = checkpoint(block, tokens, coords, use_reentrant=use_reentrant) @@ -374,7 +374,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) ) - def forward(self, tokens, coords=None, use_reentrant=True): + def forward(self, tokens, coords=None, use_reentrant=False): for block in self.ae_global_blocks: if isinstance(block, MultiSelfAttentionHead | MultiSelfAttentionHeadLocal): tokens = checkpoint(block, tokens, coords, use_reentrant=use_reentrant) From 6e7f1edf9cdbf55be2410709cd1065302b0c14a2 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Thu, 11 Dec 2025 09:24:55 +0100 Subject: [PATCH 14/36] use_2d_rope false as default --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index c3d0cfc35..9516ce036 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -66,7 +66,7 @@ healpix_level: 5 # Use 2D RoPE instead of traditional global positional encoding # When True: uses 2D RoPE based on healpix cell coordinates
(lat/lon) # When False: uses traditional pe_global positional encoding -use_2D_rope: True +use_2D_rope: False with_mixed_precision: True with_flash_attention: True From ba3f579e5aac8b6249174a0a2338d162571d49af Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Fri, 12 Dec 2025 15:46:12 +0100 Subject: [PATCH 15/36] Add copyright notice for RoPE functions and update naming - Add copyright attribution for rotate_half() and apply_rotary_pos_emb() functions - Rename apply_rotary_pos_emb_2d() to rotary_pos_emb_2d() for consistency - Rename config parameter use_2D_rope to rope_2D for better extensibility when supporting different RoPE variants in the future --- config/default_config.yml | 2 +- src/weathergen/model/attention.py | 4 ++-- src/weathergen/model/model.py | 13 ++++++------- src/weathergen/model/positional_encoding.py | 10 ++++++---- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 9516ce036..ef8ebd036 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -66,7 +66,7 @@ healpix_level: 5 # Use 2D RoPE instead of traditional global positional encoding # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) # When False: uses traditional pe_global positional encoding -use_2D_rope: False +rope_2D: False with_mixed_precision: True with_flash_attention: True diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 227e8c47a..f5ac24e40 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,7 +14,7 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.norms import AdaLayerNorm, RMSNorm -from weathergen.model.positional_encoding import apply_rotary_pos_emb_2d +from weathergen.model.positional_encoding import rotary_pos_emb_2d class MultiSelfAttentionHeadVarlen(torch.nn.Module): @@ -255,7 +255,7 @@ def forward(self, x, coords=None, ada_ln_aux=None): vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3]) if coords is not None: - qs, ks = apply_rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1) + qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1) outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index b6f639dd5..e296f2280 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -84,9 +84,9 @@ def __init__(self, cf) -> None: ) self.pe_global = torch.nn.Parameter(pe, requires_grad=False) - ### ROPE COORDS (for 2D RoPE when use_2D_rope=True) ### - self.use_2D_rope = cf.use_2D_rope - if self.use_2D_rope: + ### ROPE COORDS ### + self.rope_2D = cf.get("rope_2D", False) + if self.rope_2D: self.rope_coords = torch.nn.Parameter( torch.zeros( bs, @@ -168,7 +168,7 @@ def reset_parameters(self, cf: Config) -> "ModelParams": dim_embed = cf.ae_global_dim_embed - if self.use_2D_rope: + if self.rope_2D: # Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE. 
# Shape: (num_healpix_cells, ae_local_num_queries, 2) verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) @@ -680,8 +680,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std # Apply 2D RoPE coords only on the first forecast step - is_first_forecast = fstep == forecast_offset - tokens = self.forecast(model_params, tokens, fstep, apply_rope=is_first_forecast) + tokens = self.forecast(model_params, tokens, fstep, apply_rope = fstep == forecast_offset) # prediction for final step preds_all += [ @@ -878,7 +877,7 @@ def forecast( ValueError: For unexpected arguments in checkpoint method """ - if apply_rope and model_params.use_2D_rope: + if apply_rope and model_params.rope_2D: tokens = self.forecast_engine(tokens, fstep, coords=model_params.rope_coords) else: tokens = self.forecast_engine(tokens, fstep, coords=None) diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index f73faee1c..b690b9b92 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -97,9 +97,11 @@ def positional_encoding_harmonic_coord(x, lats, lons): #################################################################################################### -# Rotary positional embeddings (2D) adapted from Qwen3 & LLama for reuse in WeatherGenerator. -# rotate_half () and apply_rotary_pos_emb () from: -# https://github.com/qiuzh20/gated_attention/blob/main/modeling_qwen3.py +# The functions rotate_half() and apply_rotary_pos_emb() below are derived from LLaMA and Qwen3 models, +# originally developed by Meta Platforms, Inc., The Qwen team, Alibaba Group and the HuggingFace Inc. team, +# licensed under the Apache License, Version 2.0. 
+# Source: https://github.com/qiuzh20/gated_attention/blob/main/modeling_qwen3.py + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -166,7 +168,7 @@ def rotary_embedding_2d(coords, dim, base=10000.0): #################################################################################################### -def apply_rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): +def rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): """Convenience wrapper that builds 2D RoPE embeddings and applies them to q/k.""" cos, sin = rotary_embedding_2d(coords, q.shape[-1], base=base) From 7e0aff21f610ae4e50bb8e128857e13d562f07d1 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Fri, 12 Dec 2025 15:51:41 +0100 Subject: [PATCH 16/36] fix lint --- src/weathergen/model/attention.py | 2 +- src/weathergen/model/model.py | 2 +- src/weathergen/model/positional_encoding.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index f5ac24e40..293be661c 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -544,7 +544,7 @@ def forward(self, x, coords=None, ada_ln_aux=None): vs = self.proj_heads_v(x).reshape(s).to(self.dtype) if coords is not None: - qs, ks = apply_rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=2) + qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=2) # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index e296f2280..3e63b5879 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -680,7 +680,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std # Apply 2D RoPE coords only on the first forecast step - tokens = self.forecast(model_params, tokens, fstep, apply_rope = fstep == forecast_offset) + tokens = self.forecast(model_params, tokens, fstep, apply_rope=fstep == forecast_offset) # prediction for final step preds_all += [ diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index b690b9b92..434ab9f65 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -102,6 +102,7 @@ def positional_encoding_harmonic_coord(x, lats, lons): # licensed under the Apache License, Version 2.0. 
# Source: https://github.com/qiuzh20/gated_attention/blob/main/modeling_qwen3.py + def rotate_half(x): """Rotates half the hidden dims of the input.""" From 2bc76dd2c6a978aaa53aeb696d255155929e9877 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Fri, 12 Dec 2025 16:01:19 +0100 Subject: [PATCH 17/36] fix lint --- src/weathergen/model/positional_encoding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index 434ab9f65..a2ed95317 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -97,9 +97,9 @@ def positional_encoding_harmonic_coord(x, lats, lons): #################################################################################################### -# The functions rotate_half() and apply_rotary_pos_emb() below are derived from LLaMA and Qwen3 models, -# originally developed by Meta Platforms, Inc., The Qwen team, Alibaba Group and the HuggingFace Inc. team, -# licensed under the Apache License, Version 2.0. +# The functions rotate_half() and apply_rotary_pos_emb() below are derived from LLaMA and Qwen3 +# models, originally developed by Meta Platforms, Inc., The Qwen team, Alibaba Group and the +# HuggingFace Inc. team, licensed under the Apache License, Version 2.0. # Source: https://github.com/qiuzh20/gated_attention/blob/main/modeling_qwen3.py From 11761bd651cf94804c5c48c164baf5e19828df9d Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Sun, 14 Dec 2025 22:47:00 +0100 Subject: [PATCH 18/36] add 2d rope to all forecast steps --- src/weathergen/model/model.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 3e63b5879..addb5e669 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -679,8 +679,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca if noise_std > 0.0: tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std - # Apply 2D RoPE coords only on the first forecast step - tokens = self.forecast(model_params, tokens, fstep, apply_rope=fstep == forecast_offset) + tokens = self.forecast(model_params, tokens, fstep) # prediction for final step preds_all += [ @@ -861,26 +860,20 @@ def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> return tokens ######################################### - def forecast( - self, model_params: ModelParams, tokens: torch.Tensor, fstep: int, apply_rope: bool = False - ) -> torch.Tensor: + def forecast(self, model_params: ModelParams, tokens: torch.Tensor, fstep: int) -> torch.Tensor: """Advances latent space representation in time Args: model_params : Query and embedding parameters (never used) tokens : Input tokens to be processed by the model. fstep: Current forecast step index (can be used as aux info). - apply_rope: Whether to apply 2D RoPE coords (only on first forecast step). 
Returns: Processed tokens Raises: ValueError: For unexpected arguments in checkpoint method """ - if apply_rope and model_params.rope_2D: - tokens = self.forecast_engine(tokens, fstep, coords=model_params.rope_coords) - else: - tokens = self.forecast_engine(tokens, fstep, coords=None) + tokens = self.forecast_engine(tokens, fstep, coords=model_params.rope_coords) return tokens From c6938fb34b0ac034c8fc5936cad26089ad340593 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Tue, 30 Dec 2025 12:05:44 +0100 Subject: [PATCH 19/36] more confs --- src/weathergen/model/model.py | 172 ---------------------------------- 1 file changed, 172 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 959f96cc7..5641db9fd 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -21,12 +21,9 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config -<<<<<<< HEAD from weathergen.datasets.utils import healpix_verts_rots, r3tos2 -======= from weathergen.datasets.batch import ModelBatch from weathergen.model.encoder import EncoderModule ->>>>>>> develop from weathergen.model.engines import ( EnsPredictionHead, ForecastingEngine, @@ -642,176 +639,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # prediction for final step output = self.predict(model_params, batch.get_forecast_steps(), tokens, batch, output) -<<<<<<< HEAD - latents = {} - latents["posteriors"] = posteriors - - return ModelOutput(physical=preds_all, latent=latents) - - ######################################### - def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: - """Embeds input data for each stream separately and rearranges it to cell-wise order - Args: - model_params : Query and embedding parameters - streams_data : Used to initialize first tokens for pre-processing - Returns: - Tokens for local assimilation - """ - - device = next(self.parameters()).device - tokens_all = self.embed_engine(streams_data, model_params.pe_embed, self.dtype, device) - - return tokens_all - - ######################################### - def assimilate_local( - self, model_params: ModelParams, tokens: torch.Tensor, cell_lens: torch.Tensor - ) -> torch.Tensor: - """Processes embedded tokens locally and prepares them for the global assimilation - Args: - model_params : Query and embedding parameters - tokens : Input tokens to be processed by local assimilation - cell_lens : Used to identify range of tokens to use from generated tokens in cell - embedding - Returns: - Tokens for global assimilation - """ - - batch_size = ( - self.cf.batch_size_per_gpu if self.training else self.cf.batch_size_validation_per_gpu - ) - - s = self.q_cells.shape - # print( f'{np.prod(np.array(tokens.shape))} :: {np.prod(np.array(s))}' - # + ':: {np.prod(np.array(tokens.shape))/np.prod(np.array(s))}') - # TODO: test if positional encoding is needed here - if self.cf.ae_local_queries_per_cell: - tokens_global = (self.q_cells + model_params.pe_global).repeat(batch_size, 1, 1) - else: - tokens_global = ( - self.q_cells.repeat(self.num_healpix_cells, 1, 1) + model_params.pe_global - ) - q_cells_lens = torch.cat( - [model_params.q_cells_lens[0].unsqueeze(0)] - + [model_params.q_cells_lens[1:] for _ in range(batch_size)] - ) - - # local assimilation model - # for block in self.ae_local_blocks: - # tokens = checkpoint(block, tokens, cell_lens, use_reentrant=False) - - # if self.cf.latent_noise_kl_weight > 0.0: - # tokens, posteriors = 
self.interpolate_latents.interpolate_with_noise( - # tokens, sampling=self.training - # ) - # else: - # tokens, posteriors = tokens, 0.0 - - # for block in self.ae_adapter: - # tokens_global = checkpoint( - # block, - # tokens_global, - # tokens, - # q_cells_lens, - # cell_lens, - # use_reentrant=False, - # ) - - # work around to bug in flash attention for hl>=5 - - cell_lens = cell_lens[1:] - clen = self.num_healpix_cells // (2 if self.cf.healpix_level <= 5 else 8) - tokens_global_unmasked_all = [] - posteriors = [] - zero_pad = torch.zeros(1, device=tokens.device, dtype=torch.int32) - for i in range((cell_lens.shape[0]) // clen): - # make sure we properly catch all elements in last chunk - i_end = (i + 1) * clen if i < (cell_lens.shape[0] // clen) - 1 else cell_lens.shape[0] - l0, l1 = ( - (0 if i == 0 else cell_lens[: i * clen].cumsum(0)[-1]), - cell_lens[:i_end].cumsum(0)[-1], - ) - - tokens_c = tokens[l0:l1] - tokens_global_c = tokens_global[i * clen : i_end] - cell_lens_c = torch.cat([zero_pad, cell_lens[i * clen : i_end]]) - q_cells_lens_c = q_cells_lens[: cell_lens_c.shape[0]] - - # local assimilation model - tokens_c = self.ae_local_engine(tokens_c, cell_lens_c, use_reentrant=False) - - if self.cf.latent_noise_kl_weight > 0.0: - tokens_c, posteriors_c = self.interpolate_latents.interpolate_with_noise( - tokens_c, sampling=self.training - ) - posteriors += [posteriors_c] - else: - tokens_c, posteriors = tokens_c, 0.0 - - # create mask for global tokens, without first element (used for padding) - mask_c = cell_lens_c[1:].to(torch.bool) - tokens_global_unmasked_c = tokens_global_c[mask_c] - q_cells_lens_unmasked_c = torch.cat([zero_pad, q_cells_lens_c[1:][mask_c]]) - cell_lens_unmasked_c = torch.cat([zero_pad, cell_lens_c[1:][mask_c]]) - - if l0 == l1 or tokens_c.shape[0] == 0: - tokens_global_unmasked_all += [tokens_global_unmasked_c] - continue - - # local to global adapter engine - tokens_global_unmasked_c = self.ae_local_global_engine( - tokens_c, - tokens_global_unmasked_c, - q_cells_lens_unmasked_c, - cell_lens_unmasked_c, - use_reentrant=False, - ) - - tokens_global_unmasked_all += [tokens_global_unmasked_c] - - tokens_global_unmasked = torch.cat(tokens_global_unmasked_all) - - # query aggregation engine on the query tokens in unmasked cells - # (applying this here assumes batch_size=1) - # permute to use ae_local_num_queries as the batchsize and no_of_tokens - # as seq len for flash attention - tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) - tokens_global_unmasked = self.ae_aggregation_engine( - tokens_global_unmasked, coords=None, use_reentrant=False - ) - tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) - - # create mask from cell lens - mask = cell_lens.to(torch.bool) - - # fill empty tensor using mask for positions of unmasked tokens - tokens_global[mask] = tokens_global_unmasked.to(tokens_global.dtype) - - # recover batch dimension and build global token list - tokens_global = ( - tokens_global.reshape([batch_size, self.num_healpix_cells, s[-2], s[-1]]) - + model_params.pe_global - ).flatten(1, 2) - - return tokens_global, posteriors - - ######################################### - def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> torch.Tensor: - """Performs transformer based global assimilation in latent space - Args: - model_params : Query and embedding parameters (never used) - tokens : Input tokens to be pre-processed by global assimilation - Returns: - Latent representation of the 
model - """ - - # global assimilation engine and adapter - tokens = self.ae_global_engine(tokens, coords=model_params.rope_coords, use_reentrant=False) - - return tokens -======= return output ->>>>>>> develop ######################################### def forecast(self, model_params: ModelParams, tokens: torch.Tensor, fstep: int) -> torch.Tensor: From cca2c23d652d340bbb15ca92a96a2417049b7424 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Tue, 30 Dec 2025 12:08:46 +0100 Subject: [PATCH 20/36] add missing enumerate --- src/weathergen/model/engines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 05b69d3e8..5f3b9dbd6 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -471,7 +471,7 @@ def init_weights_final(m): def forward(self, tokens, fstep, coords=None): aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") - for _b_idx, block in self.fe_blocks: + for _b_idx, block in enumerate(self.fe_blocks): if isinstance(block, torch.nn.modules.normalization.LayerNorm): tokens = block(tokens) elif isinstance(block, MultiSelfAttentionHead | MultiSelfAttentionHeadLocal): From 34378c657bc81a4e5d595998d0727e9a39789576 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Tue, 30 Dec 2025 14:12:04 +0100 Subject: [PATCH 21/36] def forecast config --- config/default_forecast_config.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/config/default_forecast_config.yml b/config/default_forecast_config.yml index c7822d327..98b778a02 100644 --- a/config/default_forecast_config.yml +++ b/config/default_forecast_config.yml @@ -63,6 +63,8 @@ fe_impute_latent_noise_std: 1e-4 # 1e-4 healpix_level: 5 +rope_2D: False + with_mixed_precision: True with_flash_attention: True compile_model: False From e39f56bcbccb876cebf885c9117ac14b7dcc5417 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Tue, 30 Dec 2025 15:18:30 +0100 Subject: [PATCH 22/36] aux_info=None in Forecast Eng forward --- src/weathergen/model/engines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 5f3b9dbd6..99921cced 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -470,7 +470,7 @@ def init_weights_final(m): block.apply(init_weights_final) def forward(self, tokens, fstep, coords=None): - aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") + aux_info = None for _b_idx, block in enumerate(self.fe_blocks): if isinstance(block, torch.nn.modules.normalization.LayerNorm): tokens = block(tokens) From 3abf1886a8a9e9b25aa8845bef4b89a8af394d71 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Tue, 30 Dec 2025 15:19:49 +0100 Subject: [PATCH 23/36] lint --- src/weathergen/datasets/multi_stream_data_sampler.py | 2 +- src/weathergen/model/engines.py | 2 +- src/weathergen/model/model.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index b6bf869c2..c3e211129 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -808,4 +808,4 @@ def worker_workset(self): + f" : dataset [{local_start},{local_end}) : [{iter_start},{iter_end})" ) - return iter_start, iter_end \ No newline at end of file + return iter_start, iter_end diff --git a/src/weathergen/model/engines.py 
b/src/weathergen/model/engines.py index 99921cced..39ebb4724 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -478,7 +478,7 @@ def forward(self, tokens, fstep, coords=None): tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) else: tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) - + return tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 5641db9fd..3fd3d74e2 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -21,8 +21,8 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config -from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.datasets.batch import ModelBatch +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.encoder import EncoderModule from weathergen.model.engines import ( EnsPredictionHead, From d3cebcf71d2cc54f81ce6a3270b7b7dde7701674 Mon Sep 17 00:00:00 2001 From: wang85 Date: Fri, 23 Jan 2026 13:33:53 +0100 Subject: [PATCH 24/36] add rope to global engine, which was moved to encoder --- src/weathergen/model/encoder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index c8efe49f3..704884857 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -120,7 +120,7 @@ def forward(self, model_params, batch): global_tokens, posteriors = self.assimilate_local(model_params, stream_cell_tokens, batch) - global_tokens = self.assimilate_global(global_tokens) + global_tokens = self.assimilate_global(model_params, global_tokens) return global_tokens, posteriors @@ -251,7 +251,7 @@ def assimilate_local( return tokens_global, posteriors - def assimilate_global(self, tokens: torch.Tensor) -> torch.Tensor: + def assimilate_global(self, model_params, tokens: torch.Tensor) -> torch.Tensor: """Performs transformer based global assimilation in latent space Args: model_params : Query and embedding parameters (never used) @@ -261,6 +261,6 @@ def assimilate_global(self, tokens: torch.Tensor) -> torch.Tensor: """ # global assimilation engine and adapter - tokens = self.ae_global_engine(tokens, use_reentrant=False) + tokens = self.ae_global_engine(tokens, coords=model_params.rope_coords, use_reentrant=False) return tokens From 131de8ae6a9c2e4339e0f5a87b3bd256b445711c Mon Sep 17 00:00:00 2001 From: wang85 Date: Fri, 23 Jan 2026 15:17:19 +0100 Subject: [PATCH 25/36] 1)init attention module with_2d_rope and rope_learnable_freq 2) add code for learnable frequencey --- config/default_config.yml | 1 + src/weathergen/model/attention.py | 58 +++++++++++++++++++-- src/weathergen/model/engines.py | 12 +++++ src/weathergen/model/positional_encoding.py | 27 ++++------ 4 files changed, 76 insertions(+), 22 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index ad0ce6dba..ba4d0f8fc 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -70,6 +70,7 @@ healpix_level: 5 # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) # When False: uses traditional pe_global positional encoding rope_2D: False +rope_learnable_freq: False with_mixed_precision: True with_flash_attention: True diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 293be661c..069160d39 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -198,6 +198,8 @@ def 
__init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + with_2d_rope=False, + rope_learnable_freq=False, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -233,6 +235,27 @@ def __init__( self.dtype = attention_dtype assert with_flash, "Only flash attention supported." + if rope_learnable_freq and not with_2d_rope: + raise ValueError("rope_learnable_freq requires with_2d_rope=True") + + self.with_2d_rope = with_2d_rope + self.rope_learnable_freq = rope_learnable_freq + if self.with_2d_rope and (self.dim_head_proj % 4 != 0): + raise ValueError( + f"2D rotary embeddings require dim to be divisible by 4; got {self.dim_head_proj}" + ) + + half_dim = self.dim_head_proj // 2 + base = 10000.0 + inv_freq_lat = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim)) + inv_freq_lon = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim)) + if self.rope_learnable_freq: + self.rope_inv_freq_lat = torch.nn.Parameter(inv_freq_lat.clone()) + self.rope_inv_freq_lon = torch.nn.Parameter(inv_freq_lon.clone()) + else: + self.register_buffer("rope_inv_freq_lat", inv_freq_lat) + self.register_buffer("rope_inv_freq_lon", inv_freq_lon) + # define block mask def mask_block_local(batch, head, idx_q, idx_kv): return (idx_q // block_factor) == (idx_kv // block_factor) @@ -254,8 +277,10 @@ def forward(self, x, coords=None, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3]) - if coords is not None: - qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1) + if self.with_2d_rope: + if coords is None: + raise ValueError("coords must be provided when with_2d_rope=True") + qs, ks = rotary_pos_emb_2d(qs, ks, coords, self.rope_inv_freq_lat, self.rope_inv_freq_lon, unsqueeze_dim=1) outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) @@ -491,6 +516,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + with_2d_rope=False, + rope_learnable_freq=False, ): super(MultiSelfAttentionHead, self).__init__() @@ -531,6 +558,27 @@ def __init__( self.att = self.attention self.softmax = torch.nn.Softmax(dim=-1) + if rope_learnable_freq and not with_2d_rope: + raise ValueError("rope_learnable_freq requires with_2d_rope=True") + + self.with_2d_rope = with_2d_rope + self.rope_learnable_freq = rope_learnable_freq + if self.with_2d_rope and (self.dim_head_proj % 4 != 0): + raise ValueError( + f"2D rotary embeddings require dim to be divisible by 4; got {self.dim_head_proj}" + ) + + half_dim = self.dim_head_proj // 2 + base = 10000.0 + inv_freq_lat = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim)) + inv_freq_lon = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim)) + if self.rope_learnable_freq: + self.rope_inv_freq_lat = torch.nn.Parameter(inv_freq_lat.clone()) + self.rope_inv_freq_lon = torch.nn.Parameter(inv_freq_lon.clone()) + else: + self.register_buffer("rope_inv_freq_lat", inv_freq_lat) + self.register_buffer("rope_inv_freq_lon", inv_freq_lon) + def forward(self, x, coords=None, ada_ln_aux=None): if self.with_residual: x_in = x @@ -543,8 +591,10 @@ def forward(self, x, coords=None, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s).to(self.dtype) - if coords is not None: - qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=2) + if self.with_2d_rope: + if coords is None: + raise ValueError("coords must be provided 
when with_2d_rope=True") + qs, ks = rotary_pos_emb_2d(qs, ks, coords, self.rope_inv_freq_lat, self.rope_inv_freq_lon, unsqueeze_dim=2) # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 39ebb4724..9e5ec8f51 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -271,6 +271,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_2d_rope=self.cf.rope_2D, + rope_learnable_freq=self.cf.rope_learnable_freq, ) ) else: @@ -286,6 +288,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_2d_rope=self.cf.rope_2D, + rope_learnable_freq=self.cf.rope_learnable_freq, ) ) # MLP block @@ -342,6 +346,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_2d_rope=self.cf.rope_2D, + rope_learnable_freq=self.cf.rope_learnable_freq, ) ) else: @@ -357,6 +363,8 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_2d_rope=self.cf.rope_2D, + rope_learnable_freq=self.cf.rope_learnable_freq, ) ) # MLP block @@ -416,6 +424,8 @@ def __init__(self, cf: Config, num_healpix_cells: int, dim_aux: int = None) -> N dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_2d_rope=self.cf.rope_2D, + rope_learnable_freq=self.cf.rope_learnable_freq, ) ) else: @@ -432,6 +442,8 @@ def __init__(self, cf: Config, num_healpix_cells: int, dim_aux: int = None) -> N dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_2d_rope=self.cf.rope_2D, + rope_learnable_freq=self.cf.rope_learnable_freq, ) ) # Add MLP block diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index a2ed95317..ccd2eecc4 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -111,7 +111,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -119,7 +119,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k: Key tensor. cos: Cosine embedding tensor. sin: Sine embedding tensor. - position_ids: Deprecated and unused; present for API compatibility. unsqueeze_dim: Dimension along which to unsqueeze cos/sin for broadcasting. """ @@ -132,32 +131,24 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): #################################################################################################### -def rotary_embedding_2d(coords, dim, base=10000.0): +def rotary_embedding_2d(coords, inv_freq_lat, inv_freq_lon): """Create 2D RoPE embeddings from latitude/longitude coordinates. Args: coords: Tensor of shape (..., 2) with coordinates in radians (lat, lon). 
- dim: Head dimension to encode; must be divisible by 4. - base: RoPE base frequency. + inv_freq_lat: Inverse frequency tensor for latitude. + inv_freq_lon: Inverse frequency tensor for longitude. Returns: - Tuple of (cos, sin) tensors with shape (..., dim). + Tuple of (cos, sin) tensors with shape (..., dim_head). """ if coords.shape[-1] != 2: raise ValueError(f"coords last dimension must be 2 (lat, lon); got {coords.shape[-1]}") - if dim % 4 != 0: - raise ValueError(f"2D rotary embeddings require dim to be divisible by 4; got {dim}") - - # Split the rotary frequencies evenly between latitude and longitude to stay local to each cell. - half_dim = dim // 2 - inv_freq = 1.0 / ( - base ** (torch.arange(0, half_dim, 2, device=coords.device, dtype=coords.dtype) / half_dim) - ) lat, lon = coords.unbind(dim=-1) - freq_lat = lat.unsqueeze(-1) * inv_freq - freq_lon = lon.unsqueeze(-1) * inv_freq + freq_lat = lat.unsqueeze(-1) * inv_freq_lat + freq_lon = lon.unsqueeze(-1) * inv_freq_lon freqs = torch.cat((freq_lat, freq_lon), dim=-1) emb = torch.cat((freqs, freqs), dim=-1) @@ -169,8 +160,8 @@ def rotary_embedding_2d(coords, dim, base=10000.0): #################################################################################################### -def rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): +def rotary_pos_emb_2d(q, k, coords, inv_freq_lat, inv_freq_lon, unsqueeze_dim=1): """Convenience wrapper that builds 2D RoPE embeddings and applies them to q/k.""" - cos, sin = rotary_embedding_2d(coords, q.shape[-1], base=base) + cos, sin = rotary_embedding_2d(coords, inv_freq_lat, inv_freq_lon) return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) From 599254193ca4f05dac95c64c728325c54da782a6 Mon Sep 17 00:00:00 2001 From: wang85 Date: Fri, 23 Jan 2026 16:18:49 +0100 Subject: [PATCH 26/36] solve some reviews --- src/weathergen/model/attention.py | 54 +++++++++------------ src/weathergen/model/engines.py | 16 ++---- src/weathergen/model/model.py | 4 +- src/weathergen/model/positional_encoding.py | 18 ++++++- 4 files changed, 46 insertions(+), 46 deletions(-) diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 069160d39..7a1e64cc8 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,8 +14,14 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.norms import AdaLayerNorm, RMSNorm -from weathergen.model.positional_encoding import rotary_pos_emb_2d +from weathergen.model.positional_encoding import build_rope_inv_freq_2d, rotary_pos_emb_2d +""" +Attention blocks used by WeatherGenerator. + +Some blocks optionally apply 2D RoPE. When enabled, the caller must provide per-token 2D +coordinates aligned with the token order (lat, lon in radians). 
+""" class MultiSelfAttentionHeadVarlen(torch.nn.Module): def __init__( @@ -240,21 +246,14 @@ def __init__( self.with_2d_rope = with_2d_rope self.rope_learnable_freq = rope_learnable_freq - if self.with_2d_rope and (self.dim_head_proj % 4 != 0): - raise ValueError( - f"2D rotary embeddings require dim to be divisible by 4; got {self.dim_head_proj}" - ) - - half_dim = self.dim_head_proj // 2 - base = 10000.0 - inv_freq_lat = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim)) - inv_freq_lon = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim)) - if self.rope_learnable_freq: - self.rope_inv_freq_lat = torch.nn.Parameter(inv_freq_lat.clone()) - self.rope_inv_freq_lon = torch.nn.Parameter(inv_freq_lon.clone()) - else: - self.register_buffer("rope_inv_freq_lat", inv_freq_lat) - self.register_buffer("rope_inv_freq_lon", inv_freq_lon) + if self.with_2d_rope: + inv_freq_lat, inv_freq_lon = build_rope_inv_freq_2d(self.dim_head_proj) + if self.rope_learnable_freq: + self.rope_inv_freq_lat = torch.nn.Parameter(inv_freq_lat) + self.rope_inv_freq_lon = torch.nn.Parameter(inv_freq_lon) + else: + self.register_buffer("rope_inv_freq_lat", inv_freq_lat) + self.register_buffer("rope_inv_freq_lon", inv_freq_lon) # define block mask def mask_block_local(batch, head, idx_q, idx_kv): @@ -563,21 +562,14 @@ def __init__( self.with_2d_rope = with_2d_rope self.rope_learnable_freq = rope_learnable_freq - if self.with_2d_rope and (self.dim_head_proj % 4 != 0): - raise ValueError( - f"2D rotary embeddings require dim to be divisible by 4; got {self.dim_head_proj}" - ) - - half_dim = self.dim_head_proj // 2 - base = 10000.0 - inv_freq_lat = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim)) - inv_freq_lon = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim)) - if self.rope_learnable_freq: - self.rope_inv_freq_lat = torch.nn.Parameter(inv_freq_lat.clone()) - self.rope_inv_freq_lon = torch.nn.Parameter(inv_freq_lon.clone()) - else: - self.register_buffer("rope_inv_freq_lat", inv_freq_lat) - self.register_buffer("rope_inv_freq_lon", inv_freq_lon) + if self.with_2d_rope: + inv_freq_lat, inv_freq_lon = build_rope_inv_freq_2d(self.dim_head_proj) + if self.rope_learnable_freq: + self.rope_inv_freq_lat = torch.nn.Parameter(inv_freq_lat) + self.rope_inv_freq_lon = torch.nn.Parameter(inv_freq_lon) + else: + self.register_buffer("rope_inv_freq_lat", inv_freq_lat) + self.register_buffer("rope_inv_freq_lon", inv_freq_lon) def forward(self, x, coords=None, ada_ln_aux=None): if self.with_residual: diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 9e5ec8f51..971cf6b2d 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -306,11 +306,9 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) def forward(self, tokens, coords=None, use_reentrant=False): + ada_ln_aux = None for block in self.ae_aggregation_blocks: - if isinstance(block, MultiSelfAttentionHead | MultiSelfAttentionHeadLocal): - tokens = checkpoint(block, tokens, coords, use_reentrant=use_reentrant) - else: - tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) + tokens = checkpoint(block, tokens, coords, ada_ln_aux, use_reentrant=use_reentrant) return tokens @@ -385,11 +383,9 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) def forward(self, tokens, coords=None, use_reentrant=False): + ada_ln_aux = None for block in self.ae_global_blocks: - if isinstance(block, MultiSelfAttentionHead | 
MultiSelfAttentionHeadLocal): - tokens = checkpoint(block, tokens, coords, use_reentrant=use_reentrant) - else: - tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) + tokens = checkpoint(block, tokens, coords, ada_ln_aux, use_reentrant=use_reentrant) return tokens @@ -486,10 +482,8 @@ def forward(self, tokens, fstep, coords=None): for _b_idx, block in enumerate(self.fe_blocks): if isinstance(block, torch.nn.modules.normalization.LayerNorm): tokens = block(tokens) - elif isinstance(block, MultiSelfAttentionHead | MultiSelfAttentionHeadLocal): - tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) else: - tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) + tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) return tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 3fd3d74e2..9237f41f2 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -105,14 +105,14 @@ def __init__(self, cf) -> None: ### ROPE COORDS ### self.rope_2D = cf.get("rope_2D", False) if self.rope_2D: - self.rope_coords = torch.nn.Parameter( + self.register_buffer( + "rope_coords", torch.zeros( bs, self.num_healpix_cells * cf.ae_local_num_queries, 2, dtype=self.dtype, ), - requires_grad=False, ) else: self.rope_coords = None diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index ccd2eecc4..c63909547 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -130,6 +130,19 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): return q_embed, k_embed +#################################################################################################### +def build_rope_inv_freq_2d(dim_head, base=10000.0, device=None, dtype=None): + """Build inverse frequencies for 2D RoPE.""" + assert dim_head % 4 == 0, ( + f"2D rotary embeddings require dim to be divisible by 4; got {dim_head}" + ) + half_dim = dim_head // 2 + inv_freq = 1.0 / ( + base ** (torch.arange(0, half_dim, 2, device=device, dtype=dtype) / half_dim) + ) + return inv_freq, inv_freq.clone() + + #################################################################################################### def rotary_embedding_2d(coords, inv_freq_lat, inv_freq_lon): """Create 2D RoPE embeddings from latitude/longitude coordinates. @@ -143,8 +156,9 @@ def rotary_embedding_2d(coords, inv_freq_lat, inv_freq_lon): Tuple of (cos, sin) tensors with shape (..., dim_head). """ - if coords.shape[-1] != 2: - raise ValueError(f"coords last dimension must be 2 (lat, lon); got {coords.shape[-1]}") + assert coords.shape[-1] == 2, ( + f"coords last dimension must be 2 (lat, lon); got {coords.shape[-1]}" + ) lat, lon = coords.unbind(dim=-1) freq_lat = lat.unsqueeze(-1) * inv_freq_lat From e7ccc235dfc7a1b87037b52423eb26588c12a87c Mon Sep 17 00:00:00 2001 From: wang85 Date: Fri, 23 Jan 2026 18:47:34 +0100 Subject: [PATCH 27/36] fix lint --- src/weathergen/model/attention.py | 11 ++++++++--- src/weathergen/model/positional_encoding.py | 4 +--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 7a1e64cc8..0e4f230cb 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -23,6 +23,7 @@ coordinates aligned with the token order (lat, lon in radians). 
""" + class MultiSelfAttentionHeadVarlen(torch.nn.Module): def __init__( self, @@ -254,7 +255,7 @@ def __init__( else: self.register_buffer("rope_inv_freq_lat", inv_freq_lat) self.register_buffer("rope_inv_freq_lon", inv_freq_lon) - + # define block mask def mask_block_local(batch, head, idx_q, idx_kv): return (idx_q // block_factor) == (idx_kv // block_factor) @@ -279,7 +280,9 @@ def forward(self, x, coords=None, ada_ln_aux=None): if self.with_2d_rope: if coords is None: raise ValueError("coords must be provided when with_2d_rope=True") - qs, ks = rotary_pos_emb_2d(qs, ks, coords, self.rope_inv_freq_lat, self.rope_inv_freq_lon, unsqueeze_dim=1) + qs, ks = rotary_pos_emb_2d( + qs, ks, coords, self.rope_inv_freq_lat, self.rope_inv_freq_lon, unsqueeze_dim=1 + ) outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) @@ -586,7 +589,9 @@ def forward(self, x, coords=None, ada_ln_aux=None): if self.with_2d_rope: if coords is None: raise ValueError("coords must be provided when with_2d_rope=True") - qs, ks = rotary_pos_emb_2d(qs, ks, coords, self.rope_inv_freq_lat, self.rope_inv_freq_lon, unsqueeze_dim=2) + qs, ks = rotary_pos_emb_2d( + qs, ks, coords, self.rope_inv_freq_lat, self.rope_inv_freq_lon, unsqueeze_dim=2 + ) # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index c63909547..b4ba356dd 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -137,9 +137,7 @@ def build_rope_inv_freq_2d(dim_head, base=10000.0, device=None, dtype=None): f"2D rotary embeddings require dim to be divisible by 4; got {dim_head}" ) half_dim = dim_head // 2 - inv_freq = 1.0 / ( - base ** (torch.arange(0, half_dim, 2, device=device, dtype=dtype) / half_dim) - ) + inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, device=device, dtype=dtype) / half_dim)) return inv_freq, inv_freq.clone() From c4604f33eeea4fba44171bd7839d191d43ecaddd Mon Sep 17 00:00:00 2001 From: wang85 Date: Fri, 23 Jan 2026 19:52:11 +0100 Subject: [PATCH 28/36] fix 2 bugs: remove rope in QueryAggregation, and change bs in model.py --- src/weathergen/model/engines.py | 4 ---- src/weathergen/model/model.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 2364dc847..bf3f92f30 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -280,8 +280,6 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - with_2d_rope=self.cf.rope_2D, - rope_learnable_freq=self.cf.rope_learnable_freq, ) ) else: @@ -297,8 +295,6 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - with_2d_rope=self.cf.rope_2D, - rope_learnable_freq=self.cf.rope_learnable_freq, ) ) # MLP block diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index a831df30c..88596e0f4 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -110,7 +110,7 @@ def __init__(self, cf) -> None: self.register_buffer( "rope_coords", torch.zeros( - bs, + cf.batch_size_per_gpu, self.num_healpix_cells * cf.ae_local_num_queries, 2, 
dtype=self.dtype, From 1a64cd4e1a67b7bcf65e18b81a3e0c01b805df8b Mon Sep 17 00:00:00 2001 From: wang85 Date: Fri, 30 Jan 2026 21:49:32 +0100 Subject: [PATCH 29/36] temporally remove learnable rope --- config/default_config.yml | 1 - src/weathergen/model/attention.py | 42 +++----------------- src/weathergen/model/engines.py | 4 -- src/weathergen/model/positional_encoding.py | 43 ++++++++++----------- 4 files changed, 25 insertions(+), 65 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 0262f0068..737a81eda 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -70,7 +70,6 @@ healpix_level: 5 # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) # When False: uses traditional pe_global positional encoding rope_2D: False -rope_learnable_freq: False with_mixed_precision: True with_flash_attention: True diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 0e4f230cb..dd3307d8b 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,7 +14,7 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.norms import AdaLayerNorm, RMSNorm -from weathergen.model.positional_encoding import build_rope_inv_freq_2d, rotary_pos_emb_2d +from weathergen.model.positional_encoding import rotary_pos_emb_2d """ Attention blocks used by WeatherGenerator. @@ -206,7 +206,6 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, with_2d_rope=False, - rope_learnable_freq=False, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -214,6 +213,7 @@ def __init__( self.with_flash = with_flash self.softcap = softcap self.with_residual = with_residual + self.with_2d_rope = with_2d_rope assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -242,20 +242,6 @@ def __init__( self.dtype = attention_dtype assert with_flash, "Only flash attention supported." 
- if rope_learnable_freq and not with_2d_rope: - raise ValueError("rope_learnable_freq requires with_2d_rope=True") - - self.with_2d_rope = with_2d_rope - self.rope_learnable_freq = rope_learnable_freq - if self.with_2d_rope: - inv_freq_lat, inv_freq_lon = build_rope_inv_freq_2d(self.dim_head_proj) - if self.rope_learnable_freq: - self.rope_inv_freq_lat = torch.nn.Parameter(inv_freq_lat) - self.rope_inv_freq_lon = torch.nn.Parameter(inv_freq_lon) - else: - self.register_buffer("rope_inv_freq_lat", inv_freq_lat) - self.register_buffer("rope_inv_freq_lon", inv_freq_lon) - # define block mask def mask_block_local(batch, head, idx_q, idx_kv): return (idx_q // block_factor) == (idx_kv // block_factor) @@ -280,9 +266,7 @@ def forward(self, x, coords=None, ada_ln_aux=None): if self.with_2d_rope: if coords is None: raise ValueError("coords must be provided when with_2d_rope=True") - qs, ks = rotary_pos_emb_2d( - qs, ks, coords, self.rope_inv_freq_lat, self.rope_inv_freq_lon, unsqueeze_dim=1 - ) + qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1) outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) @@ -519,7 +503,6 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, with_2d_rope=False, - rope_learnable_freq=False, ): super(MultiSelfAttentionHead, self).__init__() @@ -528,6 +511,7 @@ def __init__( self.softcap = softcap self.dropout_rate = dropout_rate self.with_residual = with_residual + self.with_2d_rope = with_2d_rope assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -560,20 +544,6 @@ def __init__( self.att = self.attention self.softmax = torch.nn.Softmax(dim=-1) - if rope_learnable_freq and not with_2d_rope: - raise ValueError("rope_learnable_freq requires with_2d_rope=True") - - self.with_2d_rope = with_2d_rope - self.rope_learnable_freq = rope_learnable_freq - if self.with_2d_rope: - inv_freq_lat, inv_freq_lon = build_rope_inv_freq_2d(self.dim_head_proj) - if self.rope_learnable_freq: - self.rope_inv_freq_lat = torch.nn.Parameter(inv_freq_lat) - self.rope_inv_freq_lon = torch.nn.Parameter(inv_freq_lon) - else: - self.register_buffer("rope_inv_freq_lat", inv_freq_lat) - self.register_buffer("rope_inv_freq_lon", inv_freq_lon) - def forward(self, x, coords=None, ada_ln_aux=None): if self.with_residual: x_in = x @@ -589,9 +559,7 @@ def forward(self, x, coords=None, ada_ln_aux=None): if self.with_2d_rope: if coords is None: raise ValueError("coords must be provided when with_2d_rope=True") - qs, ks = rotary_pos_emb_2d( - qs, ks, coords, self.rope_inv_freq_lat, self.rope_inv_freq_lon, unsqueeze_dim=2 - ) + qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=2) # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index bf3f92f30..a645305f9 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -350,7 +350,6 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.rope_2D, - rope_learnable_freq=self.cf.rope_learnable_freq, ) ) else: @@ -367,7 +366,6 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.rope_2D, - rope_learnable_freq=self.cf.rope_learnable_freq, ) 
) # MLP block @@ -426,7 +424,6 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.rope_2D, - rope_learnable_freq=self.cf.rope_learnable_freq, ) ) else: @@ -444,7 +441,6 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), with_2d_rope=self.cf.rope_2D, - rope_learnable_freq=self.cf.rope_learnable_freq, ) ) # Add MLP block diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index b4ba356dd..42f5bc19b 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -111,7 +111,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -119,6 +119,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): k: Key tensor. cos: Cosine embedding tensor. sin: Sine embedding tensor. + position_ids: Deprecated and unused; present for API compatibility. unsqueeze_dim: Dimension along which to unsqueeze cos/sin for broadcasting. """ @@ -131,36 +132,32 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): #################################################################################################### -def build_rope_inv_freq_2d(dim_head, base=10000.0, device=None, dtype=None): - """Build inverse frequencies for 2D RoPE.""" - assert dim_head % 4 == 0, ( - f"2D rotary embeddings require dim to be divisible by 4; got {dim_head}" - ) - half_dim = dim_head // 2 - inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, device=device, dtype=dtype) / half_dim)) - return inv_freq, inv_freq.clone() - - -#################################################################################################### -def rotary_embedding_2d(coords, inv_freq_lat, inv_freq_lon): +def rotary_embedding_2d(coords, dim, base=10000.0): """Create 2D RoPE embeddings from latitude/longitude coordinates. Args: coords: Tensor of shape (..., 2) with coordinates in radians (lat, lon). - inv_freq_lat: Inverse frequency tensor for latitude. - inv_freq_lon: Inverse frequency tensor for longitude. + dim: Head dimension to encode; must be divisible by 4. + base: RoPE base frequency. Returns: - Tuple of (cos, sin) tensors with shape (..., dim_head). + Tuple of (cos, sin) tensors with shape (..., dim). """ - assert coords.shape[-1] == 2, ( - f"coords last dimension must be 2 (lat, lon); got {coords.shape[-1]}" + if coords.shape[-1] != 2: + raise ValueError(f"coords last dimension must be 2 (lat, lon); got {coords.shape[-1]}") + if dim % 4 != 0: + raise ValueError(f"2D rotary embeddings require dim to be divisible by 4; got {dim}") + + # Split the rotary frequencies evenly between latitude and longitude to stay local to each cell. 
+ half_dim = dim // 2 + inv_freq = 1.0 / ( + base ** (torch.arange(0, half_dim, 2, device=coords.device, dtype=coords.dtype) / half_dim) ) lat, lon = coords.unbind(dim=-1) - freq_lat = lat.unsqueeze(-1) * inv_freq_lat - freq_lon = lon.unsqueeze(-1) * inv_freq_lon + freq_lat = lat.unsqueeze(-1) * inv_freq + freq_lon = lon.unsqueeze(-1) * inv_freq freqs = torch.cat((freq_lat, freq_lon), dim=-1) emb = torch.cat((freqs, freqs), dim=-1) @@ -172,8 +169,8 @@ def rotary_embedding_2d(coords, inv_freq_lat, inv_freq_lon): #################################################################################################### -def rotary_pos_emb_2d(q, k, coords, inv_freq_lat, inv_freq_lon, unsqueeze_dim=1): +def rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): """Convenience wrapper that builds 2D RoPE embeddings and applies them to q/k.""" - cos, sin = rotary_embedding_2d(coords, inv_freq_lat, inv_freq_lon) - return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) + cos, sin = rotary_embedding_2d(coords, q.shape[-1], base=base) + return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) \ No newline at end of file From 00a7fdd1b52bc6e06abb1292f6f2ab52d6683ff3 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Fri, 30 Jan 2026 23:25:44 +0100 Subject: [PATCH 30/36] add rope for register and class tokens; fix lint --- config/default_config.yml | 2 ++ src/weathergen/model/encoder.py | 7 ++++++- src/weathergen/model/model.py | 10 ++++++++-- src/weathergen/model/positional_encoding.py | 2 +- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 11c5d7e4b..a9c957a32 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -70,6 +70,8 @@ healpix_level: 5 # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) # When False: uses traditional pe_global positional encoding rope_2D: False +batch_size_per_gpu: 1 +batch_size_validation_per_gpu: 1 with_mixed_precision: True with_flash_attention: True diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 93482c6f2..b15f3ce86 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -125,7 +125,12 @@ def forward(self, model_params, batch): self.assimilate_local, model_params, stream_cell_tokens, batch, use_reentrant=False ) - tokens_global = checkpoint(self.ae_global_engine, tokens_global, coords=model_params.rope_coords, use_reentrant=False) + tokens_global = checkpoint( + self.ae_global_engine, + tokens_global, + coords=model_params.rope_coords, + use_reentrant=False, + ) return tokens_global, posteriors diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 225ec154e..211d4150b 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -108,11 +108,15 @@ def __init__(self, cf) -> None: ### ROPE COORDS ### self.rope_2D = cf.get("rope_2D", False) if self.rope_2D: + self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens + total_tokens = ( + self.num_healpix_cells + self.num_extra_tokens + ) * cf.ae_local_num_queries self.register_buffer( "rope_coords", torch.zeros( cf.batch_size_per_gpu, - self.num_healpix_cells * cf.ae_local_num_queries, + total_tokens, 2, dtype=self.dtype, ), @@ -186,7 +190,9 @@ def reset_parameters(self, cf: Config) -> "ModelParams": coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) coords_flat = 
coords.flatten(0, 1).unsqueeze(0).repeat(bs, 1, 1) - self.rope_coords.data.copy_(coords_flat) + offset = self.num_extra_tokens * cf.ae_local_num_queries + self.rope_coords.data.fill_(0.0) + self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) # Clear pe_global when using 2D RoPE self.pe_global.data.fill_(0.0) diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index 42f5bc19b..a2ed95317 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -173,4 +173,4 @@ def rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): """Convenience wrapper that builds 2D RoPE embeddings and applies them to q/k.""" cos, sin = rotary_embedding_2d(coords, q.shape[-1], base=base) - return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) \ No newline at end of file + return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) From d1634f90eefc0fba77757783e1f0dd9702e59a70 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Fri, 30 Jan 2026 23:55:39 +0100 Subject: [PATCH 31/36] rename aux_info in queryaggregation --- src/weathergen/model/engines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 0af4c8db9..683c6f33d 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -306,11 +306,11 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: def forward(self, tokens, batch_lens, use_reentrant, coords=None): for block in self.ae_aggregation_blocks: - ada_ln_aux = None + aux_info = None if isinstance(block, MultiSelfAttentionHeadVarlen): tokens = block(tokens, x_lens=batch_lens) else: - tokens = block(tokens, coords, ada_ln_aux) + tokens = block(tokens, coords, aux_info) return tokens From b60224ebced78841f57751c92f3831649e3476d9 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Sat, 31 Jan 2026 08:52:04 +0100 Subject: [PATCH 32/36] remove position_ids and change raise valueError to assert --- src/weathergen/model/positional_encoding.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index a2ed95317..b3eac2ec4 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -111,7 +111,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -119,7 +119,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k: Key tensor. cos: Cosine embedding tensor. sin: Sine embedding tensor. - position_ids: Deprecated and unused; present for API compatibility. unsqueeze_dim: Dimension along which to unsqueeze cos/sin for broadcasting. """ @@ -144,11 +143,13 @@ def rotary_embedding_2d(coords, dim, base=10000.0): Tuple of (cos, sin) tensors with shape (..., dim). 
""" - if coords.shape[-1] != 2: - raise ValueError(f"coords last dimension must be 2 (lat, lon); got {coords.shape[-1]}") - if dim % 4 != 0: - raise ValueError(f"2D rotary embeddings require dim to be divisible by 4; got {dim}") - + assert coords.shape[-1] == 2, ( + f"coords last dimension must be 2 (lat, lon); got {coords.shape[-1]}" + ) + assert dim % 4 == 0, ( + f"2D rotary embeddings require dim to be divisible by 4; got {dim}" + ) + # Split the rotary frequencies evenly between latitude and longitude to stay local to each cell. half_dim = dim // 2 inv_freq = 1.0 / ( From 7efef08c28b98ad07151598a5a9fabb2705cac10 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Sat, 31 Jan 2026 09:31:39 +0100 Subject: [PATCH 33/36] batch size get from get_batch_size_from_config() --- src/weathergen/model/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 211d4150b..b79658168 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -36,6 +36,7 @@ ) from weathergen.model.layers import MLP, NamedLinear from weathergen.model.utils import get_num_parameters +from weathergen.train.utils import get_batch_size_from_config from weathergen.utils.distributed import is_root from weathergen.utils.utils import get_dtype @@ -112,10 +113,11 @@ def __init__(self, cf) -> None: total_tokens = ( self.num_healpix_cells + self.num_extra_tokens ) * cf.ae_local_num_queries + batch_size_per_gpu = get_batch_size_from_config(cf.training_config) self.register_buffer( "rope_coords", torch.zeros( - cf.batch_size_per_gpu, + batch_size_per_gpu, total_tokens, 2, dtype=self.dtype, @@ -169,7 +171,7 @@ def reset_parameters(self, cf: Config) -> "ModelParams": # positional encodings - bs = cf.batch_size_per_gpu + bs = get_batch_size_from_config(cf.training_config) dim_embed = cf.ae_local_dim_embed len_token_seq = 1024 self.pe_embed.data.fill_(0.0) From ecf75b23af52f31defac49c7fba018e6446a992d Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Sat, 31 Jan 2026 09:34:30 +0100 Subject: [PATCH 34/36] revert to default config without batchsize --- config/default_config.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index a9c957a32..11c5d7e4b 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -70,8 +70,6 @@ healpix_level: 5 # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) # When False: uses traditional pe_global positional encoding rope_2D: False -batch_size_per_gpu: 1 -batch_size_validation_per_gpu: 1 with_mixed_precision: True with_flash_attention: True From cdd6088a8b0f9adde6d0e09cb344955904eb11c8 Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Sat, 31 Jan 2026 09:42:57 +0100 Subject: [PATCH 35/36] use self.batch_size --- src/weathergen/model/model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index b79658168..7f21d6dc9 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -91,6 +91,7 @@ def __init__(self, cf) -> None: self.healpix_level = cf.healpix_level self.num_healpix_cells = 12 * 4**cf.healpix_level self.dtype = get_dtype(cf.attention_dtype) + self.batch_size_per_gpu = get_batch_size_from_config(cf.training_config) ### POSITIONAL EMBEDDINGS ### len_token_seq = 1024 @@ -113,11 +114,10 @@ def __init__(self, cf) -> None: total_tokens = ( self.num_healpix_cells + self.num_extra_tokens ) * 
cf.ae_local_num_queries - batch_size_per_gpu = get_batch_size_from_config(cf.training_config) self.register_buffer( "rope_coords", torch.zeros( - batch_size_per_gpu, + self.batch_size_per_gpu, total_tokens, 2, dtype=self.dtype, @@ -171,7 +171,6 @@ def reset_parameters(self, cf: Config) -> "ModelParams": # positional encodings - bs = get_batch_size_from_config(cf.training_config) dim_embed = cf.ae_local_dim_embed len_token_seq = 1024 self.pe_embed.data.fill_(0.0) @@ -191,7 +190,7 @@ def reset_parameters(self, cf: Config) -> "ModelParams": verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) - coords_flat = coords.flatten(0, 1).unsqueeze(0).repeat(bs, 1, 1) + coords_flat = coords.flatten(0, 1).unsqueeze(0).repeat(self.batch_size_per_gpu, 1, 1) offset = self.num_extra_tokens * cf.ae_local_num_queries self.rope_coords.data.fill_(0.0) self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) From b76ceb147ee1002e1e814a8731a2598c2e1df22f Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Sat, 31 Jan 2026 09:44:18 +0100 Subject: [PATCH 36/36] fix lint --- src/weathergen/model/positional_encoding.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index b3eac2ec4..411942a9f 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -146,10 +146,8 @@ def rotary_embedding_2d(coords, dim, base=10000.0): assert coords.shape[-1] == 2, ( f"coords last dimension must be 2 (lat, lon); got {coords.shape[-1]}" ) - assert dim % 4 == 0, ( - f"2D rotary embeddings require dim to be divisible by 4; got {dim}" - ) - + assert dim % 4 == 0, f"2D rotary embeddings require dim to be divisible by 4; got {dim}" + # Split the rotary frequencies evenly between latitude and longitude to stay local to each cell. half_dim = dim // 2 inv_freq = 1.0 / (