diff --git a/config/default_config.yml b/config/default_config.yml index 613078ebe..9b3442d7a 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -66,6 +66,11 @@ forecast_att_dense_rate: 1.0 healpix_level: 5 +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: False + with_mixed_precision: True with_flash_attention: True compile_model: False diff --git a/config/default_forecast_config.yml b/config/default_forecast_config.yml index 0080ed252..44f891a00 100644 --- a/config/default_forecast_config.yml +++ b/config/default_forecast_config.yml @@ -63,6 +63,8 @@ fe_impute_latent_noise_std: 1e-4 # 1e-4 healpix_level: 5 +rope_2D: False + with_mixed_precision: True with_flash_attention: True compile_model: False diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 99606bdce..3ee450b6b 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,6 +14,14 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.norms import AdaLayerNorm, RMSNorm +from weathergen.model.positional_encoding import rotary_pos_emb_2d + +""" +Attention blocks used by WeatherGenerator. + +Some blocks optionally apply 2D RoPE. When enabled, the caller must provide per-token 2D +coordinates aligned with the token order (lat, lon in radians). +""" class MultiSelfAttentionHeadVarlen(torch.nn.Module): @@ -197,6 +205,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + with_2d_rope=False, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -204,6 +213,7 @@ def __init__( self.with_flash = with_flash self.softcap = softcap self.with_residual = with_residual + self.with_2d_rope = with_2d_rope assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -242,7 +252,7 @@ def mask_block_local(batch, head, idx_q, idx_kv): # compile for efficiency self.flex_attention = torch.compile(flex_attention, dynamic=False) - def forward(self, x, ada_ln_aux=None): + def forward(self, x, coords=None, ada_ln_aux=None): if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) @@ -253,6 +263,11 @@ def forward(self, x, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3]) + if self.with_2d_rope: + if coords is None: + raise ValueError("coords must be provided when with_2d_rope=True") + qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1) + outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) out = self.proj_out(self.dropout(outs.flatten(-2, -1))) @@ -487,6 +502,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + with_2d_rope=False, ): super(MultiSelfAttentionHead, self).__init__() @@ -495,6 +511,7 @@ def __init__( self.softcap = softcap self.dropout_rate = dropout_rate self.with_residual = with_residual + self.with_2d_rope = with_2d_rope assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -527,7 +544,7 @@ def __init__( self.att = self.attention self.softmax = torch.nn.Softmax(dim=-1) - def forward(self, x, ada_ln_aux=None): + def forward(self, x, coords=None, ada_ln_aux=None): if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) @@ -539,6 +556,11 @@ def forward(self, x, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s).to(self.dtype) + if self.with_2d_rope: + if coords is None: + raise ValueError("coords must be provided when with_2d_rope=True") + qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=2) + # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 47e059014..b15f3ce86 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -125,7 +125,12 @@ def forward(self, model_params, batch): self.assimilate_local, model_params, stream_cell_tokens, batch, use_reentrant=False ) - tokens_global = checkpoint(self.ae_global_engine, tokens_global, use_reentrant=False) + tokens_global = checkpoint( + self.ae_global_engine, + tokens_global, + coords=model_params.rope_coords, + use_reentrant=False, + ) return tokens_global, posteriors diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index de5328a93..b4615fd11 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -304,12 +304,13 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) ) - def forward(self, tokens, batch_lens, use_reentrant): + def forward(self, tokens, batch_lens, use_reentrant, coords=None): for block in self.ae_aggregation_blocks: + aux_info = None if isinstance(block, MultiSelfAttentionHeadVarlen): tokens = block(tokens, x_lens=batch_lens) else: - tokens = block(tokens) + tokens = block(tokens, coords, aux_info) return tokens @@ -345,6 +346,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_2d_rope=self.cf.rope_2D, ) ) else: @@ -360,6 +362,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_2d_rope=self.cf.rope_2D, ) ) # MLP block @@ -379,9 +382,10 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) ) - def forward(self, tokens): + def forward(self, tokens, coords=None): + aux_info = None for block in self.ae_global_blocks: - tokens = block(tokens) + tokens = block(tokens, coords, aux_info) return tokens @@ -416,6 +420,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_2d_rope=self.cf.rope_2D, ) ) else: @@ -432,6 +437,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_2d_rope=self.cf.rope_2D, ) ) # Add MLP block @@ -461,7 +467,7 @@ def init_weights_final(m): for block in self.fe_blocks: block.apply(init_weights_final) - def forward(self, tokens, fstep): + def forward(self, tokens, fstep, coords=None): if self.training: # Impute noise to the latent state noise_std = self.cf.get("fe_impute_latent_noise_std", 0.0) @@ -473,7 +479,7 @@ def forward(self, tokens, fstep): if isinstance(block, torch.nn.modules.normalization.LayerNorm): tokens = block(tokens) else: - tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) + tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) return tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 4ba54a483..273c28838 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -21,6 +21,7 @@ from weathergen.common.config import Config from weathergen.datasets.batch import ModelBatch +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.encoder import EncoderModule from weathergen.model.engines import ( BilinearDecoder, @@ -35,6 +36,7 @@ ) from weathergen.model.layers import MLP, NamedLinear from weathergen.model.utils import get_num_parameters +from weathergen.train.utils import get_batch_size_from_config from weathergen.utils.distributed import is_root from weathergen.utils.utils import get_dtype, is_stream_forcing @@ -89,6 +91,7 @@ def __init__(self, cf) -> None: self.healpix_level = cf.healpix_level self.num_healpix_cells = 12 * 4**cf.healpix_level self.dtype = get_dtype(cf.attention_dtype) + self.batch_size_per_gpu = get_batch_size_from_config(cf.training_config) ### POSITIONAL EMBEDDINGS ### len_token_seq = 1024 @@ -104,6 +107,25 @@ def __init__(self, cf) -> None: ) self.pe_global = torch.nn.Parameter(pe, requires_grad=False) + ### ROPE COORDS ### + self.rope_2D = cf.get("rope_2D", False) + if self.rope_2D: + self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens + total_tokens = ( + self.num_healpix_cells + self.num_extra_tokens + ) * cf.ae_local_num_queries + self.register_buffer( + "rope_coords", + torch.zeros( + self.batch_size_per_gpu, + total_tokens, + 2, + dtype=self.dtype, + ), + ) + else: + self.rope_coords = None + ### HEALPIX NEIGHBOURS ### hlc = self.healpix_level with warnings.catch_warnings(action="ignore"): @@ -161,28 +183,57 @@ def reset_parameters(self, cf: Config) -> "ModelParams": self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]]) dim_embed = cf.ae_global_dim_embed - self.pe_global.data.fill_(0.0) - xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed - self.pe_global.data[..., 0::2] = 0.5 * torch.sin( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 0::2] += ( - torch.sin( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + + if self.rope_2D: + # Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE. + # Shape: (num_healpix_cells, ae_local_num_queries, 2) + verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) + coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) + coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) + coords_flat = coords.flatten(0, 1).unsqueeze(0).repeat(self.batch_size_per_gpu, 1, 1) + offset = self.num_extra_tokens * cf.ae_local_num_queries + self.rope_coords.data.fill_(0.0) + self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) + + # Clear pe_global when using 2D RoPE + self.pe_global.data.fill_(0.0) + else: + # Original pe_global initialization + self.pe_global.data.fill_(0.0) + xs = ( + 2.0 + * np.pi + * torch.arange(0, dim_embed, 2, device=self.pe_global.device) + / dim_embed ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) - self.pe_global.data[..., 1::2] = 0.5 * torch.cos( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 1::2] += ( - torch.cos( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + self.pe_global.data[..., 0::2] = 0.5 * torch.sin( + torch.outer( + 8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs + ) + ) + self.pe_global.data[..., 0::2] += ( + torch.sin( + torch.outer( + torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs + ) + ) + .unsqueeze(1) + .repeat((1, cf.ae_local_num_queries, 1)) + ) + self.pe_global.data[..., 1::2] = 0.5 * torch.cos( + torch.outer( + 8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs + ) + ) + self.pe_global.data[..., 1::2] += ( + torch.cos( + torch.outer( + torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs + ) + ) + .unsqueeze(1) + .repeat((1, cf.ae_local_num_queries, 1)) ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) # healpix neighborhood structure @@ -588,7 +639,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: for step in batch.get_output_idxs(): # apply forecasting engine (if present) if self.forecast_engine: - tokens = self.forecast_engine(tokens, step) + tokens = self.forecast_engine(tokens, step, coords=model_params.rope_coords) # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index 88df67fa3..411942a9f 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -94,3 +94,82 @@ def positional_encoding_harmonic_coord(x, lats, lons): x = x + pe return x + + +#################################################################################################### +# The functions rotate_half() and apply_rotary_pos_emb() below are derived from LLaMA and Qwen3 +# models, originally developed by Meta Platforms, Inc., The Qwen team, Alibaba Group and the +# HuggingFace Inc. team, licensed under the Apache License, Version 2.0. +# Source: https://github.com/qiuzh20/gated_attention/blob/main/modeling_qwen3.py + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q: Query tensor. + k: Key tensor. + cos: Cosine embedding tensor. + sin: Sine embedding tensor. + unsqueeze_dim: Dimension along which to unsqueeze cos/sin for broadcasting. + """ + + cos = cos.unsqueeze(unsqueeze_dim).to(dtype=q.dtype) + sin = sin.unsqueeze(unsqueeze_dim).to(dtype=q.dtype) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed, k_embed + + +#################################################################################################### +def rotary_embedding_2d(coords, dim, base=10000.0): + """Create 2D RoPE embeddings from latitude/longitude coordinates. + + Args: + coords: Tensor of shape (..., 2) with coordinates in radians (lat, lon). + dim: Head dimension to encode; must be divisible by 4. + base: RoPE base frequency. + + Returns: + Tuple of (cos, sin) tensors with shape (..., dim). + """ + + assert coords.shape[-1] == 2, ( + f"coords last dimension must be 2 (lat, lon); got {coords.shape[-1]}" + ) + assert dim % 4 == 0, f"2D rotary embeddings require dim to be divisible by 4; got {dim}" + + # Split the rotary frequencies evenly between latitude and longitude to stay local to each cell. + half_dim = dim // 2 + inv_freq = 1.0 / ( + base ** (torch.arange(0, half_dim, 2, device=coords.device, dtype=coords.dtype) / half_dim) + ) + + lat, lon = coords.unbind(dim=-1) + freq_lat = lat.unsqueeze(-1) * inv_freq + freq_lon = lon.unsqueeze(-1) * inv_freq + + freqs = torch.cat((freq_lat, freq_lon), dim=-1) + emb = torch.cat((freqs, freqs), dim=-1) + + cos = torch.cos(emb) + sin = torch.sin(emb) + + return cos, sin + + +#################################################################################################### +def rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): + """Convenience wrapper that builds 2D RoPE embeddings and applies them to q/k.""" + + cos, sin = rotary_embedding_2d(coords, q.shape[-1], base=base) + return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim)