
Move parameter_init to train stepper configs #814

Open

jpdunc23 wants to merge 22 commits into main from refactor/move-param-init-to-train-configs

Conversation

jpdunc23 (Member) commented on Feb 10, 2026

Continuation of the separation of training-specific concerns from inference stepper configs, building on #809.

This PR makes backwards-incompatible changes affecting fme.coupled training and fine-tuning configs:

  • Training (init from separate uncoupled component checkpoints): Each component's parameter_init: ParameterInitializationConfig is now configured using its respective ComponentTrainingConfig on the CoupledTrainStepperConfig.
  • Fine-tuning (init from a coupled stepper checkpoint): Moved parameter_init: CoupledParameterInitConfig out of CoupledStepperConfig and into CoupledTrainStepperConfig.

Existing fme.ace training YAML configs will continue to work without changes, with parameter_init now transferred to TrainStepperConfig via StepperConfig.get_train_stepper_config(). This backwards compatibility will be removed in a future PR.

Changes:

  • CoupledTrainStepperConfig now owns both CoupledParameterInitConfig and the per-component ParameterInitializationConfig (via ComponentTrainingConfig).

  • StepperConfig.get_stepper() signature changed to accept an optional ParameterInitializer instead of a boolean flag.

  • TrainStepperConfig now owns parameter_init and builds both the initializer and the underlying Stepper (see the simplified sketch after this list).

  • Tests added.

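For orientation, here is a heavily simplified sketch of the ownership and signature changes above. Everything below is illustrative stand-in code, not the actual fme classes, which carry many more fields and real construction logic:

    import dataclasses
    from typing import Optional

    class Stepper:
        """Stand-in for the inference stepper."""

    class ParameterInitializer:
        """Stand-in for the object that applies parameter initialization to a stepper."""

        def apply(self, stepper: Stepper) -> None:
            pass

    @dataclasses.dataclass
    class ParameterInitializationConfig:
        """Stand-in for the per-stepper parameter initialization config."""

        def build(self) -> ParameterInitializer:
            return ParameterInitializer()

    @dataclasses.dataclass
    class StepperConfig:
        # Inference-only fields elided.

        def get_stepper(
            self, parameter_initializer: Optional[ParameterInitializer] = None
        ) -> Stepper:
            # New signature: an optional initializer rather than a boolean flag.
            stepper = Stepper()
            if parameter_initializer is not None:
                parameter_initializer.apply(stepper)
            return stepper

    @dataclasses.dataclass
    class TrainStepperConfig:
        parameter_init: ParameterInitializationConfig = dataclasses.field(
            default_factory=ParameterInitializationConfig
        )

        def get_train_stepper(self, stepper_config: StepperConfig) -> "TrainStepper":
            # TrainStepperConfig builds both the initializer and the underlying Stepper.
            initializer = self.parameter_init.build()
            stepper = stepper_config.get_stepper(parameter_initializer=initializer)
            return TrainStepper(stepper=stepper, config=self)

    @dataclasses.dataclass
    class TrainStepper:
        stepper: Stepper
        config: TrainStepperConfig
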
jpdunc23 changed the base branch from main to refactor/coupled-train-stepper on February 10, 2026 20:52
jpdunc23 mentioned this pull request on Feb 10, 2026
Base automatically changed from refactor/coupled-train-stepper to main on February 11, 2026 07:56
jpdunc23 marked this pull request as ready for review on February 12, 2026 20:36
jpdunc23 requested a review from mcgibbon on February 13, 2026 21:54
"""
Configuration for a stepper.

The following fields are training concerns transferred to TrainStepperConfig
Contributor:

nit: Probably don't do this and just wait, but a clearer way to communicate this is to define a class _NewStepperConfig() used internally, move get_stepper to that config, and replace this old config with the new one in your next PR. In other words, you could fully implement the new sub-configs while keeping the current YAML config layer, instead of only implementing the new Training one. But what's here is fine if you prefer refactoring it in that way in the next PR.
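
(Illustrative sketch of the pattern being suggested here; the class names and bodies below are placeholders rather than code from this PR.)

    import dataclasses

    @dataclasses.dataclass
    class _NewStepperConfig:
        """Internal config that carries the real behavior and later becomes the public StepperConfig."""

        def get_stepper(self):
            ...  # real construction logic would live here

    @dataclasses.dataclass
    class StepperConfig:
        """Existing YAML-facing config, reduced to a translation layer."""

        def _to_new_config(self) -> _NewStepperConfig:
            # Map the old YAML fields onto the new layout.
            return _NewStepperConfig()

        def get_stepper(self):
            return self._to_new_config().get_stepper()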

jpdunc23 (Member, Author) replied on Feb 18, 2026:

Not sure if this changes anything for you, but my plan for a (hopefully small) future PR is to remove these attributes from StepperConfig and add the train_stepper: TrainStepperConfig attribute to TrainConfig.

This planned PR would cause backwards-incompatible changes to fme.ace training, which we can communicate on Monday's technical sync so that folks have plenty of notice.

Or do you think it would be better to maintain training backwards-compatibility going forward? I personally feel we should just rip off the bandaid and switch to the new train_stepper config style.

Contributor:

I'm agreed on the destination, the plan sounds good to me.

Another way to say my comment is, treating the existing StepperConfig like a facade with no real functionality beyond "construct the new classes" would give you the freedom to write all the features of those classes now in the way they will eventually be defined. Then you could have a final PR that amounts to "delete the old yaml config and its translation layer, move the new ones in place as public API" with no real feature changes in that final PR.

Right now, your implementation of what will be the final "StepperConfig" (what I refer to in the comment as a potentially temporarily named "_NewStepperConfig") is being blocked on the breaking YAML changes, because of the self-imposed requirement that the new features be implemented directly onto StepperConfig instead of temporarily onto a new class that later replaces StepperConfig. As a result you have this strange construction series you need to document for this temporary intermediate period, for example.

There's no action needed here, we can keep going on the current route, but this kind of pattern happens all the time when doing refactors and I thought it important to bring up.

jpdunc23 (Member, Author):

OK, thanks for clarifying. What I misunderstood in your first comment was that I thought you were suggesting we add _NewStepperConfig in a future PR. For future refactors I will definitely try to remember to use the pattern you're suggesting here, which I agree would have been much cleaner.

return TimeLengthSchedule.from_constant(self.train_n_forward_steps)

def get_train_stepper(self, stepper: Stepper) -> "TrainStepper":
def get_parameter_initializer(
Contributor:

Suggestion (optional): Make this function private. I know it wasn't private in the previous config (which is why this is optional), but it seems it was only used privately and probably should be private.

loss=StepLossConfig(type="MSE"),
)
stepper = train_stepper_config.get_train_stepper(unittest.mock.Mock())
stepper = TrainStepper(stepper=unittest.mock.Mock(), config=train_stepper_config)
Contributor:

Question: Why is this kind of refactor needed? Can't the train stepper config operate with a default parameter init config?

jpdunc23 (Member, Author):

This is related to your comment https://github.com/ai2cm/ace/pull/814/changes#r2823350718. Since TrainStepperConfig.get_train_stepper now builds the Stepper instance itself, I avoid calling it in this test and directly initialize TrainStepper instead. Otherwise I would need to pass a more complicated mock StepperConfig to get_train_stepper.
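
(For concreteness, a sketch of the alternative being avoided; this is not code from the PR, just an illustration using unittest.mock, where the mock StepperConfig would need its get_stepper configured to return something stepper-like before being passed to get_train_stepper.)

    import unittest.mock

    mock_stepper_config = unittest.mock.Mock()
    mock_stepper_config.get_stepper.return_value = unittest.mock.Mock()
    # train_stepper = train_stepper_config.get_train_stepper(mock_stepper_config)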

load_weights_and_history=load_weights_and_history_fn
)

def get_train_stepper(
Contributor:

(No action for this PR)

I'm realizing the dependencies are a bit back-and-forth here, but after diving into it fairly deeply I think it can't be easily avoided. It was really nice before to have get_train_stepper(stepper: Stepper) take in an already-built object; otherwise the builder here has to build two objects, which is more complex. It means we've tightly coupled the build of these two decoupleable things.

However, the reason it's coupled is that we need to freeze weights before wrapping in the distributed module wrapper, meaning in the current flow it has to happen during the stepper init.

Let's leave this on the backburner for the moment, but perhaps the TrainStepper should be transforming the Stepper's modules during its initialization. The distributed module wrapper's responsibility is also training-specific - it's not technically needed for inference; the purpose is to properly calculate gradients in a batch-distributed context. It could be injected by the train stepper at the same time as weight freezing. That way the StepperConfig's get_stepper wouldn't need to take in a fairly heavy object like ParameterInitializationConfig.
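
(A rough sketch of that idea, with all names as illustrative stand-ins rather than code from the repo: the train stepper applies the training-only module transforms itself, so the stepper build stays training-agnostic.)

    class Stepper:
        def __init__(self):
            self.modules: list = []  # stand-in for the stepper's torch modules

    def freeze_parameters(module):
        """Stand-in for weight freezing driven by the training config."""
        return module

    def wrap_for_distributed(module):
        """Stand-in for the batch-distributed gradient wrapper."""
        return module

    class TrainStepper:
        def __init__(self, stepper: Stepper):
            # Training-only transforms happen here, in order: freeze first, then wrap,
            # matching the constraint that freezing must precede distributed wrapping.
            stepper.modules = [
                wrap_for_distributed(freeze_parameters(m)) for m in stepper.modules
            ]
            self.stepper = stepper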

jpdunc23 (Member, Author):

> However, the reason it's coupled is that we need to freeze weights before wrapping in the distributed module wrapper, meaning in the current flow it has to happen during the stepper init.

Exactly, this forced my hand in doing it this way.

> Let's leave this on the backburner for the moment, but perhaps the TrainStepper should be transforming the Stepper's modules during its initialization. The distributed module wrapper's responsibility is also training-specific - it's not technically needed for inference

That's a good idea. This might be worth tackling as the next step in this sequence of PRs.

loss=StepLossConfig(type="MSE"),
)
stepper = train_stepper_config.get_train_stepper(unittest.mock.Mock())
stepper = TrainStepper(stepper=unittest.mock.Mock(), config=train_stepper_config)
Contributor:

Suggestion (optional): write a _get_stepper(train_stepper_config) or similar helper function to avoid needing to refactor this in as many places in the future.
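
(Something along these lines, as a sketch; the helper name comes from the suggestion and the body is a guess at what it would centralize. TrainStepper would be the class already imported by the test module.)

    import unittest.mock

    def _get_stepper(train_stepper_config) -> "TrainStepper":
        # Construct the TrainStepper under test with a mocked underlying Stepper,
        # so future signature changes only need to be updated in one place.
        return TrainStepper(stepper=unittest.mock.Mock(), config=train_stepper_config)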

mcgibbon (Contributor) left a comment:

LGTM!
