Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion fme/core/registry/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def __post_init__(self):
)
self._instance = self.registry.get(self.type, self.config)

@property
def module_config(self) -> ModuleConfig:
    """Return the ModuleConfig instance built for this selector.

    ``self._instance`` is produced in ``__post_init__`` via
    ``self.registry.get(self.type, self.config)``.
    """
    return self._instance

@classmethod
def register(
cls, type_name: str
Expand All @@ -161,7 +165,7 @@ def build(
n_in_channels: int,
n_out_channels: int,
dataset_info: DatasetInfo,
) -> nn.Module:
) -> Module:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

"""
Build a nn.Module given information about the input and output channels
and the dataset.
Expand Down
171 changes: 143 additions & 28 deletions fme/core/registry/test_module_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
import datetime
import pathlib
from collections.abc import Iterable
from typing import Any

import dacite
import pytest
import torch
import yaml

import fme
from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates
from fme.core.dataset_info import DatasetInfo
from fme.core.labels import LabelEncoding
from fme.core.rand import set_seed
from fme.core.registry.module import Module

from .module import CONDITIONAL_BUILDERS, ModuleConfig, ModuleSelector
Expand Down Expand Up @@ -79,8 +82,27 @@ def test_module_selector_raises_with_bad_config():
ModuleSelector(type="mock", config={"non_existent_key": 1})


def get_dbc2925_ncsfno_module() -> tuple[ModuleSelector, Module]:
    """Build the NoiseConditionedSFNO module frozen at commit dbc2925.

    Returns the selector and the built module so the backwards-compatibility
    test can load the committed checkpoint into this exact architecture.

    NOTE(review): the DatasetInfo setup and config dict are intentionally
    duplicated from get_noise_conditioned_sfno_module (per PR discussion) so
    that this frozen build stays pinned even if the "latest" builder changes.
    Part of the config dict was collapsed in the diff view and was restored
    from the identical sibling builder — confirm against the original file.
    """
    img_shape = (9, 18)
    n_in_channels = 5
    n_out_channels = 6
    all_labels = {"a", "b"}
    timestep = datetime.timedelta(hours=6)
    device = fme.get_device()
    horizontal_coordinate = LatLonCoordinates(
        lat=torch.zeros(img_shape[0], device=device),
        lon=torch.zeros(img_shape[1], device=device),
    )
    vertical_coordinate = HybridSigmaPressureCoordinate(
        ak=torch.arange(7, device=device), bk=torch.arange(7, device=device)
    )
    dataset_info = DatasetInfo(
        horizontal_coordinates=horizontal_coordinate,
        vertical_coordinate=vertical_coordinate,
        timestep=timestep,
        all_labels=all_labels,
    )
    selector = ModuleSelector(
        type="NoiseConditionedSFNO",
        config={
            "embed_dim": 8,
            "noise_embed_dim": 4,
            "noise_type": "isotropic",
            "filter_type": "linear",
            "use_mlp": True,
            "num_layers": 4,
            "operator_type": "dhconv",
            "affine_norms": True,
            "spectral_transform": "sht",
        },
    )
    module = selector.build(
        n_in_channels=n_in_channels,
        n_out_channels=n_out_channels,
        dataset_info=dataset_info,
    )
    return selector, module


def load_or_cache_state(selector_name: str, module: Module) -> dict[str, torch.Tensor]:
    """Load the cached state dict for ``selector_name``; create it if absent.

    When no cached file exists, the module's current state is written to disk
    and an error is raised so the new artifact gets committed before the test
    can pass on a subsequent run.
    """
    path = DATA_DIR / f"{selector_name}_state_dict.pt"
    if not path.exists():
        torch.save(module.get_state(), path)
        raise RuntimeError(
            f"State dict for {selector_name} not found. "
            f"Created a new one at {path}. "
            "Please commit it to the repo and run the test again."
        )
    return torch.load(path)


SELECTORS = {
"NoiseConditionedSFNO": get_noise_conditioned_sfno_module_selector(),
}


@pytest.mark.parametrize(
"selector_name",
SELECTORS.keys(),
)
def test_module_backwards_compatibility(selector_name: str):
torch.manual_seed(0)
def get_noise_conditioned_sfno_module() -> tuple[ModuleSelector, Module]:
    """Build the latest NoiseConditionedSFNO module for the
    backwards-compatibility tests.

    Returns the selector and the built module; the selector is also used to
    compare the current config schema against the cached one.

    NOTE(review): this function was reconstructed around a collapsed diff
    hunk; the DatasetInfo setup was restored from the identical
    get_dbc2925_ncsfno_module builder — confirm against the original file.
    """
    img_shape = (9, 18)
    n_in_channels = 5
    n_out_channels = 6
    all_labels = {"a", "b"}
    timestep = datetime.timedelta(hours=6)
    device = fme.get_device()
    horizontal_coordinate = LatLonCoordinates(
        lat=torch.zeros(img_shape[0], device=device),
        lon=torch.zeros(img_shape[1], device=device),
    )
    vertical_coordinate = HybridSigmaPressureCoordinate(
        ak=torch.arange(7, device=device), bk=torch.arange(7, device=device)
    )
    dataset_info = DatasetInfo(
        horizontal_coordinates=horizontal_coordinate,
        vertical_coordinate=vertical_coordinate,
        timestep=timestep,
        all_labels=all_labels,
    )
    selector = ModuleSelector(
        type="NoiseConditionedSFNO",
        config={
            "embed_dim": 8,
            "noise_embed_dim": 4,
            "noise_type": "isotropic",
            "filter_type": "linear",
            "use_mlp": True,
            "num_layers": 4,
            "operator_type": "dhconv",
            "affine_norms": True,
            "spectral_transform": "sht",
        },
    )
    module = selector.build(
        n_in_channels=n_in_channels,
        n_out_channels=n_out_channels,
        dataset_info=dataset_info,
    )
    return selector, module


def load_state(selector_name: str) -> dict[str, torch.Tensor]:
    """Load the committed state dict for ``selector_name``.

    Unlike ``load_or_cache_state``, a missing checkpoint is a hard error:
    frozen artifacts must never be regenerated automatically.
    """
    path = DATA_DIR / f"{selector_name}_state_dict.pt"
    if path.exists():
        return torch.load(path)
    raise RuntimeError(
        f"State dict for {selector_name} not found at {path}. "
        "Please make sure the checkpoint exists and is committed to the repo."
    )


def load_or_cache_state(
    selector_name: str, module: Module, module_config: ModuleConfig | None = None
) -> dict[str, torch.Tensor]:
    """Load the cached state dict for ``selector_name``; create it if absent.

    Args:
        selector_name: Key used to name the cached artifact on disk.
        module: Module whose state is saved when no cache exists yet.
        module_config: Unused (flagged in review). Kept with a default so any
            existing callers keep working; remove once call sites are checked.

    Raises:
        AssertionError: If no cached state existed. A fresh one is written to
            disk so it can be committed, and the test must be re-run.
    """
    state_dict_path = DATA_DIR / f"{selector_name}_state_dict.pt"
    if state_dict_path.exists():
        return torch.load(state_dict_path)
    state_dict = module.get_state()
    torch.save(state_dict, state_dict_path)
    raise AssertionError(
        f"State dict for {selector_name} not found. "
        f"Created a new one at {state_dict_path}. "
        "Please commit it to the repo and run the test again."
    )


def load_or_cache_module_config(
    selector_name: str, module_config: dict[str, Any]
) -> dict[str, Any]:
    """Load the cached module config for ``selector_name``; create it if absent.

    When no cached YAML exists, ``module_config`` is dumped to disk and an
    AssertionError is raised so the new artifact gets committed before the
    test can pass on a subsequent run.
    """
    path = DATA_DIR / f"{selector_name}_module_config.yaml"
    if not path.exists():
        with open(path, "w") as f:
            yaml.safe_dump(module_config, f)
        raise AssertionError(
            f"Module config for {selector_name} not found. "
            f"Created a new one at {path}. "
            "Please commit it to the repo and run the test again."
        )
    with open(path) as f:
        return yaml.safe_load(f)


# Builders for "frozen" module snapshots pinned to a specific commit; their
# committed checkpoints are loaded via load_state and must never be
# regenerated automatically.
FROZEN_BUILDERS = {
    "dbc2925_ncsfno": get_dbc2925_ncsfno_module,
}


@pytest.mark.parametrize(
    "selector_name",
    FROZEN_BUILDERS.keys(),
)
def test_frozen_module_backwards_compatibility(selector_name: str):
    """
    Backwards compatibility for frozen releases from specific commits.

    Builds the frozen architecture and verifies that the committed checkpoint
    still loads into it (per review discussion, load_state rejects mismatched
    keys — confirm strict=True is the default).
    """
    set_seed(0)
    builder = FROZEN_BUILDERS[selector_name]
    _, module = builder()
    module.load_state(load_state(selector_name))
Comment on lines 222 to 229
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still a bit concerned that if we place no limits on new keys then we could inadvertently introduce changes in behavior that lead to regressions in inference skill with FROZEN_BUILDERS checkpoints. I wonder if we should also raise an error here with a message along the lines of:

"New module parameters {new_keys} found that were not present in the 
"frozen" checkpoint. New module parameters may be added but should 
be enabled by adding a new config parameter that, by default, does not 
add the new parameters when building the module."

We could also save and reload the module config dict together with the artifact to verify that the config builds the same architecture. Of course there are a million other ways to change the module code that could lead to inference regressions, but I don't see why we should allow arbitrary new parameters that weren't present when the checkpoint was saved if there is a way to avoid it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That type of thing is supposed to be covered by the "produces the same result" test(s). If you're concerned about new parameters not affecting an initial prediction but affecting later ones after the first gradient update, I should add a second stage to those tests that does a second step when testing for identicality.

I see though now what you're saying, I've not been understanding it. In practice, what I have here won't catch any of the cases I care about updating the regression tests for, because we always add them in a way that sets the weights to None (which doesn't get registered in the state dict). Really what we need is to remember to update/write a new test when we add features that define new weights.

What I actually want to do is test that the config has no new keys, and force the user to build a new latest checkpoint when new config keys are added, saving the asdict'd config with the checkpoint. I'll see about adding that, and also adding what you suggested about making sure the config builds the same architecture.

Copy link
Contributor Author

@mcgibbon mcgibbon Feb 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at it more, and because we load model states with strict=True (the default), this will already error if the built model keys differ from the checkpoint. I can see how my other (wrong) test implied this wasn't the case, though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point. Maybe we should add a test that confirms this behavior for Module.load_state (and maybe also other parts of Module). But that's a preexisting issue.



# Builders tracking the latest module implementations; their cached artifacts
# are regenerated (via load_or_cache_*) whenever the code adds new weights or
# config keys.
LATEST_BUILDERS = {
    "NoiseConditionedSFNO": get_noise_conditioned_sfno_module,
}


@pytest.mark.parametrize(
    "selector_name",
    LATEST_BUILDERS.keys(),
)
def test_latest_module_backwards_compatibility(selector_name: str):
    """
    Backwards compatibility for the latest module implementations.

    Should be kept up-to-date with the latest code changes.
    """
    set_seed(0)
    selector, module = LATEST_BUILDERS[selector_name]()
    # Loading the cached weights catches added/removed parameters.
    module.load_state(load_or_cache_state(selector_name, module))
    # Compare the current config schema against the cached one: any brand-new
    # key means the checkpoint/config pair on disk must be regenerated.
    current_config = dataclasses.asdict(selector.module_config)
    cached_config = load_or_cache_module_config(selector_name, current_config)
    new_keys = set(current_config) - set(cached_config)
    assert not new_keys, (
        f"New keys {new_keys} were added to the module config of {selector_name}. "
        "If you want to ensure backwards compatibility of this new feature, "
        "you must update the configuration for this module to use that feature, "
        "then run this test to update the cached config and checkpoint, and "
        "commit those files to the repo. If you do not want to ensure backwards "
        "compatibility of this feature, you must still re-generate the checkpoint "
        "to remove this error. In either case update the checkpoint "
        "(and configuration) as its own isolated commit."
    )
34 changes: 34 additions & 0 deletions fme/core/registry/testdata/NoiseConditionedSFNO_module_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Cached ModuleConfig (as a dataclasses.asdict dump) for the "latest"
# NoiseConditionedSFNO builder. Regenerated by
# test_latest_module_backwards_compatibility when new config keys are added;
# keys only are compared, so comments here are safe.
activation_function: gelu
affine_norms: true
big_skip: true
checkpointing: 0
complex_activation: real
complex_network: true
context_pos_embed_dim: 0
data_grid: legendre-gauss
embed_dim: 8
encoder_layers: 1
factorization: null
filter_num_groups: 1
filter_output: false
filter_residual: false
filter_type: linear
global_layer_norm: false
local_blocks: null
lora_alpha: null
lora_rank: 0
mlp_ratio: 2.0
noise_embed_dim: 4
noise_type: isotropic
normalize_big_skip: false
num_layers: 4
operator_type: dhconv
pos_embed: true
rank: 1.0
residual_filter_factor: 1
separable: false
spectral_layers: 1
spectral_lora_alpha: null
spectral_lora_rank: 0
spectral_transform: sht
use_mlp: true
Binary file not shown.