Skip to content

Separate regression tests for frozen and latest checkpoints#830

Open
mcgibbon wants to merge 14 commits intomainfrom
feature/legacy_regression_tests
Open

Separate regression tests for frozen and latest checkpoints#830
mcgibbon wants to merge 14 commits intomainfrom
feature/legacy_regression_tests

Conversation

@mcgibbon
Copy link
Contributor

@mcgibbon mcgibbon commented Feb 12, 2026

This PR defines separate regression tests for "frozen" checkpoints from specific commits (which shouldn't be updated) and "latest" checkpoints ensuring ongoing backwards compatibility (which should be updated).

@mcgibbon mcgibbon marked this pull request as ready for review February 12, 2026 20:49
@mcgibbon mcgibbon requested a review from jpdunc23 February 12, 2026 20:49
Copy link
Member

@jpdunc23 jpdunc23 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still concerned about silently allowing new parameters with frozen checkpoints, though it maybe warrants a discussion at a tech sync and I'm willing to punt for now if you prefer.

Other than that, I have one blocking comment about updating the existing artifact.

n_out_channels: int,
dataset_info: DatasetInfo,
) -> nn.Module:
) -> Module:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch



LATEST_BUILDERS = {
"NoiseConditionedSFNO": get_noise_conditioned_sfno_module,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should NoiseConditionedSFNO_state_dict.pt also be updated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was already a "latest" version.

Comment on lines 158 to 165
def test_frozen_module_backwards_compatibility(selector_name: str):
"""
Backwards compatibility for frozen releases from specific commits.
"""
set_seed(0)
module = FROZEN_BUILDERS[selector_name]()
loaded_state_dict = load_state(selector_name)
module.load_state(loaded_state_dict)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still a bit concerned that if we place no limits on new keys then we could inadvertently introduce changes in behavior that lead to regressions in inference skill with FROZEN_BUILDERS checkpoints. I wonder if we should also raise an error here with a message along the lines of:

"New module parameters {new_keys} found that were not present in the 
"frozen" checkpoint. New module parameters may be added but should 
be enabled by adding a new config parameter that, by default, does not 
add the new parameters when building the module."

We could also save and reload the module config dict together with the artifact to verify that the config builds the same architecture. Of course there are a million other ways to change the module code that could lead to inference regressions, but I don't see why we should allow arbitrary new parameters that weren't present when the checkpoint was saved if there is a way to avoid it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That type of thing is supposed to be covered by the "produces the same result" test(s). If you're concerned about new parameters not affecting an initial prediction but affecting later ones after the first gradient update, I should add a second stage to those tests that does a second step when testing for identicality.

I see though now what you're saying, I've not been understanding it. In practice, what I have here won't catch any of the cases I care about updating the regression tests for, because we always add them in a way that sets the weights to None (which doesn't get registered in the state dict). Really what we need is to remember to update/write a new test when we add features that define new weights.

What I actually want to do is test that the config has no new keys, and force the user to build a new latest checkpoint when new config keys are added, saving the asdict'd config with the checkpoint. I'll see about adding that, and also adding what you suggested about making sure the config builds the same architecture.

Copy link
Contributor Author

@mcgibbon mcgibbon Feb 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at it more, and because we load model states with strict=True (the default), this will already error if the built model keys differ from the checkpoint. I can see how my other (wrong) test implied this wasn't the case, though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point. Maybe we should add a test that confirms this behavior for Module.load_state (and maybe also other parts of Module). But that's a preexisting issue.

Copy link
Member

@jpdunc23 jpdunc23 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One unused argument, otherwise just minor comments and LGTM.



def load_or_cache_state(
selector_name: str, module: Module, module_config: ModuleConfig | None = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The module_config argument to load_or_cache_state appears to be unused.

Comment on lines +86 to +104
img_shape = (9, 18)
n_in_channels = 5
n_out_channels = 6
all_labels = {"a", "b"}
timestep = datetime.timedelta(hours=6)
device = fme.get_device()
horizontal_coordinate = LatLonCoordinates(
lat=torch.zeros(img_shape[0], device=device),
lon=torch.zeros(img_shape[1], device=device),
)
vertical_coordinate = HybridSigmaPressureCoordinate(
ak=torch.arange(7, device=device), bk=torch.arange(7, device=device)
)
dataset_info = DatasetInfo(
horizontal_coordinates=horizontal_coordinate,
vertical_coordinate=vertical_coordinate,
timestep=timestep,
all_labels=all_labels,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a new / reusing an existing shared helper to create the DatasetInfo that both get_dbc2925_ncsfno_module and get_noise_conditioned_sfno_module can reuse.

I think it makes sense to repeat the select = ModuleSelector( blocks, even if they're currently identical.

Comment on lines 158 to 165
def test_frozen_module_backwards_compatibility(selector_name: str):
"""
Backwards compatibility for frozen releases from specific commits.
"""
set_seed(0)
module = FROZEN_BUILDERS[selector_name]()
loaded_state_dict = load_state(selector_name)
module.load_state(loaded_state_dict)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point. Maybe we should add a test that confirms this behavior for Module.load_state (and maybe also other parts of Module). But that's a preexisting issue.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants

Comments