diff --git a/fme/core/registry/module.py b/fme/core/registry/module.py index b1ee41d79..5b229a570 100644 --- a/fme/core/registry/module.py +++ b/fme/core/registry/module.py @@ -150,6 +150,10 @@ def __post_init__(self): ) self._instance = self.registry.get(self.type, self.config) + @property + def module_config(self) -> ModuleConfig: + return self._instance + @classmethod def register( cls, type_name: str @@ -161,7 +165,7 @@ def build( n_in_channels: int, n_out_channels: int, dataset_info: DatasetInfo, - ) -> nn.Module: + ) -> Module: """ Build a nn.Module given information about the input and output channels and the dataset. diff --git a/fme/core/registry/test_module_registry.py b/fme/core/registry/test_module_registry.py index d2dce58e9..43a3240d9 100644 --- a/fme/core/registry/test_module_registry.py +++ b/fme/core/registry/test_module_registry.py @@ -2,15 +2,18 @@ import datetime import pathlib from collections.abc import Iterable +from typing import Any import dacite import pytest import torch +import yaml import fme from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates from fme.core.dataset_info import DatasetInfo from fme.core.labels import LabelEncoding +from fme.core.rand import set_seed from fme.core.registry.module import Module from .module import CONDITIONAL_BUILDERS, ModuleConfig, ModuleSelector @@ -79,8 +82,27 @@ def test_module_selector_raises_with_bad_config(): ModuleSelector(type="mock", config={"non_existent_key": 1}) -def get_noise_conditioned_sfno_module_selector() -> ModuleSelector: - return ModuleSelector( +def get_dbc2925_ncsfno_module() -> tuple[ModuleSelector, Module]: + img_shape = (9, 18) + n_in_channels = 5 + n_out_channels = 6 + all_labels = {"a", "b"} + timestep = datetime.timedelta(hours=6) + device = fme.get_device() + horizontal_coordinate = LatLonCoordinates( + lat=torch.zeros(img_shape[0], device=device), + lon=torch.zeros(img_shape[1], device=device), + ) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7, device=device), bk=torch.arange(7, device=device) + ) + dataset_info = DatasetInfo( + horizontal_coordinates=horizontal_coordinate, + vertical_coordinate=vertical_coordinate, + timestep=timestep, + all_labels=all_labels, + ) + selector = ModuleSelector( type="NoiseConditionedSFNO", config={ "embed_dim": 8, @@ -94,33 +116,15 @@ def get_noise_conditioned_sfno_module_selector() -> ModuleSelector: "spectral_transform": "sht", }, ) + module = selector.build( + n_in_channels=n_in_channels, + n_out_channels=n_out_channels, + dataset_info=dataset_info, + ) + return selector, module -def load_or_cache_state(selector_name: str, module: Module) -> dict[str, torch.Tensor]: - state_dict_path = DATA_DIR / f"{selector_name}_state_dict.pt" - if state_dict_path.exists(): - return torch.load(state_dict_path) - else: - state_dict = module.get_state() - torch.save(state_dict, state_dict_path) - raise RuntimeError( - f"State dict for {selector_name} not found. " - f"Created a new one at {state_dict_path}. " - "Please commit it to the repo and run the test again." - ) - - -SELECTORS = { - "NoiseConditionedSFNO": get_noise_conditioned_sfno_module_selector(), -} - - -@pytest.mark.parametrize( - "selector_name", - SELECTORS.keys(), -) -def test_module_backwards_compatibility(selector_name: str): - torch.manual_seed(0) +def get_noise_conditioned_sfno_module() -> tuple[ModuleSelector, Module]: img_shape = (9, 18) n_in_channels = 5 n_out_channels = 6 @@ -140,10 +144,121 @@ def test_module_backwards_compatibility(selector_name: str): timestep=timestep, all_labels=all_labels, ) - module = SELECTORS[selector_name].build( + selector = ModuleSelector( + type="NoiseConditionedSFNO", + config={ + "embed_dim": 8, + "noise_embed_dim": 4, + "noise_type": "isotropic", + "filter_type": "linear", + "use_mlp": True, + "num_layers": 4, + "operator_type": "dhconv", + "affine_norms": True, + "spectral_transform": "sht", + }, + ) + module = selector.build( n_in_channels=n_in_channels, n_out_channels=n_out_channels, dataset_info=dataset_info, ) + return selector, module + + +def load_state(selector_name: str) -> dict[str, torch.Tensor]: + state_dict_path = DATA_DIR / f"{selector_name}_state_dict.pt" + if not state_dict_path.exists(): + raise RuntimeError( + f"State dict for {selector_name} not found at {state_dict_path}. " + "Please make sure the checkpoint exists and is committed to the repo." + ) + return torch.load(state_dict_path) + + +def load_or_cache_state( + selector_name: str, module: Module, module_config: ModuleConfig | None = None +) -> dict[str, torch.Tensor]: + state_dict_path = DATA_DIR / f"{selector_name}_state_dict.pt" + if state_dict_path.exists(): + return torch.load(state_dict_path) + else: + state_dict = module.get_state() + torch.save(state_dict, state_dict_path) + raise AssertionError( + f"State dict for {selector_name} not found. " + f"Created a new one at {state_dict_path}. " + "Please commit it to the repo and run the test again." + ) + + +def load_or_cache_module_config( + selector_name: str, module_config: dict[str, Any] +) -> dict[str, Any]: + module_config_path = DATA_DIR / f"{selector_name}_module_config.yaml" + if module_config_path.exists(): + with open(module_config_path) as f: + data = yaml.safe_load(f) + return data + else: + with open(module_config_path, "w") as f: + yaml.safe_dump(module_config, f) + raise AssertionError( + f"Module config for {selector_name} not found. " + f"Created a new one at {module_config_path}. " + "Please commit it to the repo and run the test again." + ) + + +FROZEN_BUILDERS = { + "dbc2925_ncsfno": get_dbc2925_ncsfno_module, +} + + +@pytest.mark.parametrize( + "selector_name", + FROZEN_BUILDERS.keys(), +) +def test_frozen_module_backwards_compatibility(selector_name: str): + """ + Backwards compatibility for frozen releases from specific commits. + """ + set_seed(0) + _, module = FROZEN_BUILDERS[selector_name]() + loaded_state_dict = load_state(selector_name) + module.load_state(loaded_state_dict) + + +LATEST_BUILDERS = { + "NoiseConditionedSFNO": get_noise_conditioned_sfno_module, +} + + +@pytest.mark.parametrize( + "selector_name", + LATEST_BUILDERS.keys(), +) +def test_latest_module_backwards_compatibility(selector_name: str): + """ + Backwards compatibility for the latest module implementations. + + Should be kept up-to-date with the latest code changes. + """ + set_seed(0) + selector, module = LATEST_BUILDERS[selector_name]() loaded_state_dict = load_or_cache_state(selector_name, module) module.load_state(loaded_state_dict) + # check if config has new keys and fail so we update the checkpoint if it does + module_config = dataclasses.asdict(selector.module_config) + loaded_module_config = load_or_cache_module_config(selector_name, module_config) + new_keys = set(module_config.keys()).difference(loaded_module_config.keys()) + assert not new_keys, ( + f"New keys {new_keys} were added to the module config of {selector_name}. " + "If you want to ensure backwards compatibility of this new feature, " + "you must update the configuration for this module to use that feature, " + "then run this test to update the cached config and checkpoint, and " + "commit those files to the repo. If you do not want to ensure backwards " + "compatibility of this feature, you must still re-generate the checkpoint " + "to remove this error. In either case update the checkpoint " + "(and configuration) as its own isolated commit." + ) diff --git a/fme/core/registry/testdata/NoiseConditionedSFNO_module_config.yaml b/fme/core/registry/testdata/NoiseConditionedSFNO_module_config.yaml new file mode 100644 index 000000000..287c2a336 --- /dev/null +++ b/fme/core/registry/testdata/NoiseConditionedSFNO_module_config.yaml @@ -0,0 +1,34 @@ +activation_function: gelu +affine_norms: true +big_skip: true +checkpointing: 0 +complex_activation: real +complex_network: true +context_pos_embed_dim: 0 +data_grid: legendre-gauss +embed_dim: 8 +encoder_layers: 1 +factorization: null +filter_num_groups: 1 +filter_output: false +filter_residual: false +filter_type: linear +global_layer_norm: false +local_blocks: null +lora_alpha: null +lora_rank: 0 +mlp_ratio: 2.0 +noise_embed_dim: 4 +noise_type: isotropic +normalize_big_skip: false +num_layers: 4 +operator_type: dhconv +pos_embed: true +rank: 1.0 +residual_filter_factor: 1 +separable: false +spectral_layers: 1 +spectral_lora_alpha: null +spectral_lora_rank: 0 +spectral_transform: sht +use_mlp: true diff --git a/fme/core/registry/testdata/dbc2925_ncsfno_state_dict.pt b/fme/core/registry/testdata/dbc2925_ncsfno_state_dict.pt new file mode 100644 index 000000000..f05bfa735 Binary files /dev/null and b/fme/core/registry/testdata/dbc2925_ncsfno_state_dict.pt differ