Skip to content

Allow multiple static inputs#822

Open
AnnaKwa wants to merge 9 commits intomainfrom
feature/allow-multiple-static-inputs-v2
Open

Allow multiple static inputs#822
AnnaKwa wants to merge 9 commits intomainfrom
feature/allow-multiple-static-inputs-v2

Conversation

@AnnaKwa
Copy link
Contributor

@AnnaKwa AnnaKwa commented Feb 11, 2026

This refactors the downscaling code to use the full StaticInputs class in generate calls.

  • DiffusionModel._get_input_from_coarse concats all StaticInputs.fields.data as additional input channels

  • breaking change: TrainConfig has a new field static_inputs which should be a mapping of {var_name: path_to_dataset}. This should be used to provide the static inputs fields- they will no longer be loaded automatically from the fine dataset. In the future, use_fine_topography in the config should be deprecated since the static_inputs field is the source of this information.

  • GriddedData, PairedGriddedData are updated to have a StaticInputs attribute instead of the single Topography

  • references to variables named topography are updated to the more general static_input

  • train, predict, inference, evaluator entrypoint code and their tests are updated to pass the StaticInputs to generate calls

  • last three commits are minor updates to use the correct number of input channels given the number of static inputs, add the static inputs to the from_state method, and update tests for these changes

(Exact diff of v1 branch but used cursor to organize the commits to be more reviewable)

AnnaKwa and others added 9 commits February 11, 2026 12:17
This property was unused and had a type inconsistency (returning
torch.tensor([]) instead of list[torch.Tensor] for the empty case).

Co-authored-by: Cursor <cursoragent@cursor.com>
Add a `static_inputs: dict[str, str] | None` field to TrainerConfig
that maps variable names to file paths. The build method now constructs
StaticInputs from this config rather than inferring topography from
the fine training data path.

Co-authored-by: Cursor <cursoragent@cursor.com>
…nputs

- Rename `Topography` class to `StaticInput` and export it publicly
- Update `GriddedData` and `PairedGriddedData` to store `static_inputs: StaticInputs`
  instead of `topography: Topography`
- Update `_get_input_from_coarse` in DiffusionModel to accept StaticInputs
  and loop over all fields rather than a single topography
- Rename `build_topography` to `build_static_inputs` in DataLoaderConfig
- Update `PairedDataLoaderConfig.build` to pass StaticInputs
- Remove unused `get_topography` from CheckpointModelConfig

Co-authored-by: Cursor <cursoragent@cursor.com>
Rename topography -> static_inputs throughout the codebase:
- train.py: training and validation loop variables
- predict.py: downscaler and event downscaler
- evaluator.py: evaluator and event evaluator
- inference/: downscaler, output config, work items
- _deterministic_models.py: model methods
- predictors/cascade.py: cascade predictor and config
- predictors/composite.py: patch predictor

Co-authored-by: Cursor <cursoragent@cursor.com>
Update all test files to use StaticInput/StaticInputs instead of Topography:
- test_topography.py: rename Topography -> StaticInput in test cases
- test_config.py: pass StaticInputs to DataLoaderConfig.build
- test_patching.py: use StaticInputs wrapper for patching tests
- test_models.py: update model train/generate tests
- test_predict.py: update predictor integration tests
- test_train.py: add static_inputs config to test setup
- test_train_config.yaml: add static_inputs field
- test_inference.py: update inference/downscaler tests
- test_output.py: pass StaticInputs to output config build
- test_cascade.py: update cascade predictor tests
- test_composite.py: update composite predictor tests

Co-authored-by: Cursor <cursoragent@cursor.com>
Copy link
Collaborator

@frodre frodre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @AnnaKwa, overall pretty good but still some minor suggestions and light cleanup to do before it's ready to merge.

@dataclasses.dataclass
class DiffusionModelConfig:
"""
f"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

necessary?

coarse_extent[0] * downscale_factor, coarse_extent[1] * downscale_factor
),
paired_batch_data.fine.latlon_coordinates[0],
static_inputs = StaticInputs(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would consider a helper function for creating static inputs in tests where there are multiple instances.



def test_StaticInputs_generate_from_patches():
def testStaticInputs_generate_from_patches():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_S

Suggested change
def testStaticInputs_generate_from_patches():
def test_StaticInputs_generate_from_patches():



def test_StaticInputs_serialize():
def testStaticInputs_serialize():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def testStaticInputs_serialize():
def test_StaticInputs_serialize():

else:
# Join the normalized topography to the input (see dataset for details)
topo = topography.data.unsqueeze(self._channel_axis)
topo = static_inputs.fields[0].data.unsqueeze(self._channel_axis)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are going to remove _deterministic_models, right? Only asking because the use of topography only as item[0] maybe is a bit confusing here. Ignore if this is all going to disappear.

predict_residual: bool
use_fine_topography: bool = False
use_amp_bf16: bool = False
static_inputs: dict[str, str] | None = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this field get used? There's also a static inputs on the TrainConfig that gets used to build and is fed into the build of this class.

raise ValueError(
"Topography shape must be evenly divisible by data shape. "
f"Got topography {self.topography.shape} and data {self.shape}"
f"Got topography {self.static_inputs.shape} and data {self.shape}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error still references Topography. Does this need to be updated? Could be internal to the StaticInputs so we don't have to reference attribute specifics outside of their source.



@pytest.mark.parametrize("static_inputs_on_model", [True, False])
@pytest.mark.parametrize("static_inputs_on_model", [True])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming the False is removed because we will always expect static inputs on the model? No need for parametrization.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants

Comments