Conversation
| """ | ||
| Configuration for a stepper. | ||
|
|
||
| The following fields are training concerns transferred to TrainStepperConfig |
nit: Probably don't do this and just wait, but a clearer way to communicate this is to define a class _NewStepperConfig() used internally, move get_stepper to that config, and replace this old config with the new one in your next PR. In other words, you could fully implement the new sub-configs while keeping the current YAML config layer, instead of only implementing the new Training one. But what's here is fine if you prefer refactoring it in that way in the next PR.
Not sure if this changes anything for you, but my plan for a (hopefully small) future PR is to remove these attributes from StepperConfig and add the train_stepper: TrainStepperConfig attribute to TrainConfig.
This planned PR would cause backwards-incompatible changes to fme.ace training, which we can communicate on Monday's technical sync so that folks have plenty of notice.
Or do you think it would be better to maintain training backwards-compatibility going forward? I personally feel we should just rip off the bandaid and switch to the new train_stepper config style.
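To make that destination concrete, here is a rough sketch of the planned layout, where `train_stepper: TrainStepperConfig` lives directly on `TrainConfig`. Every field name here other than `train_stepper` is a placeholder, not the actual fme definition:

```python
import dataclasses

# Illustrative sketch only: training-only concerns live on TrainStepperConfig,
# which TrainConfig owns directly instead of StepperConfig carrying them.


@dataclasses.dataclass
class TrainStepperConfig:
    parameter_init: "ParameterInitializationConfig"  # training-only concern
    loss: "StepLossConfig"                            # training-only concern


@dataclasses.dataclass
class TrainConfig:
    stepper: "StepperConfig"           # inference-relevant configuration
    train_stepper: TrainStepperConfig  # training-specific configuration
```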
I'm agreed on the destination, the plan sounds good to me.
Another way to say my comment is, treating the existing StepperConfig like a facade with no real functionality beyond "construct the new classes" would give you the freedom to write all the features of those classes now in the way they will eventually be defined. Then you could have a final PR that amounts to "delete the old yaml config and its translation layer, move the new ones in place as public API" with no real feature changes in that final PR.
Right now, implementing what will be the final "StepperConfig" (what I referred to above as a possibly temporarily named "_NewStepperConfig") is blocked on the breaking YAML changes, because of the self-imposed requirement that the new features be implemented directly on StepperConfig rather than temporarily on a new class that later replaces it. As a result you have, for example, this strange construction sequence that needs documenting for the temporary intermediate period.
There's no action needed here, we can keep going on the current route, but this kind of pattern happens all the time when doing refactors and I thought it important to bring up.
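As an illustration of the facade idea (a sketch only: the fields are placeholders, and `_NewStepperConfig` / `_NewTrainStepperConfig` are the hypothetical internal names from the comment above, not existing classes):

```python
import dataclasses


@dataclasses.dataclass
class _NewStepperConfig:
    """What StepperConfig would eventually become: inference concerns only."""

    builder_name: str  # placeholder field


@dataclasses.dataclass
class _NewTrainStepperConfig:
    """What TrainStepperConfig would own: training concerns only."""

    train_n_forward_steps: int  # placeholder field


@dataclasses.dataclass
class StepperConfig:
    """Existing YAML-facing config, kept only as a facade/translation layer."""

    builder_name: str
    train_n_forward_steps: int

    def get_stepper_config(self) -> _NewStepperConfig:
        # Translate the old YAML fields into the new inference-only config.
        return _NewStepperConfig(builder_name=self.builder_name)

    def get_train_stepper_config(self) -> _NewTrainStepperConfig:
        # Translate the old training fields into the new training-only config.
        return _NewTrainStepperConfig(
            train_n_forward_steps=self.train_n_forward_steps
        )
```

The final PR would then amount to deleting the old YAML config and its translation layer and promoting the new classes to the public API, with no feature changes in that last step.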
OK, thanks for clarifying. What I misunderstood in your first comment was that I thought you were suggesting we add _NewStepperConfig in a future PR. For future refactors I'll definitely try to remember to use the pattern you're suggesting here, which I agree would have been much cleaner.
fme/ace/stepper/single_module.py (outdated)
```diff
     return TimeLengthSchedule.from_constant(self.train_n_forward_steps)

-    def get_train_stepper(self, stepper: Stepper) -> "TrainStepper":
+    def get_parameter_initializer(
```
Suggestion (optional): Make this function private. I know it wasn't private in the previous config (which is why this is optional), but it seems it was only used privately and probably should be private.
```diff
         loss=StepLossConfig(type="MSE"),
     )
-    stepper = train_stepper_config.get_train_stepper(unittest.mock.Mock())
+    stepper = TrainStepper(stepper=unittest.mock.Mock(), config=train_stepper_config)
```
Question: Why is this kind of refactor needed? Can't the train stepper config operate with a default parameter init config?
This is related to your comment https://github.com/ai2cm/ace/pull/814/changes#r2823350718. Since TrainStepperConfig.get_train_stepper now builds the Stepper instance, I'm avoiding that by just directly initializing TrainStepper here. Otherwise I would need to pass a more complicated mock StepperConfig to get_train_stepper.
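A minimal sketch of that pattern, assuming only the constructor call visible in the diff above (`TrainStepperConfig`, `StepLossConfig`, and `TrainStepper` would be imported as in the existing test module):

```python
import unittest.mock

# Because get_train_stepper now also builds the Stepper, calling it in a unit
# test would require a realistic (or heavily mocked) StepperConfig. Building
# TrainStepper directly keeps the mock surface small: the Stepper is treated
# as already built, so a bare Mock stands in for it.
train_stepper_config = TrainStepperConfig(
    loss=StepLossConfig(type="MSE"),
)
stepper = TrainStepper(stepper=unittest.mock.Mock(), config=train_stepper_config)
```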
```diff
         load_weights_and_history=load_weights_and_history_fn
     )

+    def get_train_stepper(
```
(No action for this PR)
I'm realizing the dependencies are a bit back-and-forth here, but after diving into it fairly deeply I think it can't be easily avoided. It was really nice before to have get_train_stepper(stepper: Stepper) take in an already-built object, otherwise the builder here has to build two objects which is more complex. It means we've tightly coupled the build of these two decoupleable things.
However, the reason it's coupled is that we need to freeze weights before wrapping in the distributed module wrapper, meaning in the current flow it has to happen during the stepper init.
Let's leave this on the backburner for the moment, but perhaps the TrainStepper should be transforming the Stepper's modules during its initialization. The distributed module wrapper's responsibility is also training-specific - it's not technically needed for inference, the purpose is to properly calculate gradients in a batch-distributed context. It could be injected by the train stepper at the same time as weight freezing. That way the StepperConfig's get_stepper wouldn't need to take in a fairly heavy object like ParameterInitializationConfig.
> However, the reason it's coupled is that we need to freeze weights before wrapping in the distributed module wrapper, meaning in the current flow it has to happen during the stepper init.
Exactly, this forced my hand in doing it this way.
> Let's leave this on the backburner for the moment, but perhaps the TrainStepper should be transforming the Stepper's modules during its initialization. The distributed module wrapper's responsibility is also training-specific - it's not technically needed for inference
That's a good idea. This might be worth tackling as the next step in this sequence of PRs.
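A rough sketch of where that could land. Everything here beyond the ordering is an assumption, not the current fme code: the `modules` attribute, the freezing logic, and the injected wrapper are placeholders for whatever the parameter-init config and distributed setup actually provide.

```python
from typing import Callable

import torch.nn as nn


def _freeze(module: nn.Module) -> None:
    # Placeholder for whatever the parameter-init config decides to freeze.
    for param in module.parameters():
        param.requires_grad = False


class TrainStepper:
    def __init__(
        self,
        stepper,  # an already-built Stepper
        config,   # a TrainStepperConfig
        wrap_distributed: Callable[[nn.Module], nn.Module] = lambda m: m,
    ):
        self.stepper = stepper
        self.config = config
        # 1) Freeze weights first, while the modules are still plain nn.Modules.
        for module in stepper.modules:
            _freeze(module)
        # 2) Then apply the distributed module wrapper, a training-only concern
        #    needed for gradient synchronization, so inference code paths never
        #    see it and StepperConfig.get_stepper stays lightweight.
        stepper.modules = [wrap_distributed(m) for m in stepper.modules]
```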
```diff
         loss=StepLossConfig(type="MSE"),
     )
-    stepper = train_stepper_config.get_train_stepper(unittest.mock.Mock())
+    stepper = TrainStepper(stepper=unittest.mock.Mock(), config=train_stepper_config)
```
Suggestion (optional): write a _get_stepper(train_stepper_config) or similar helper function to avoid needing to refactor this in as many places in the future.
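One possible shape for such a helper, assuming only the constructor call shown in the diff (`TrainStepper` and `TrainStepperConfig` would be imported as in the existing test module):

```python
import unittest.mock


def _get_stepper(train_stepper_config: "TrainStepperConfig") -> "TrainStepper":
    """Test helper: wrap a mocked Stepper in a TrainStepper so that future
    changes to the construction path only need to be updated in one place."""
    return TrainStepper(stepper=unittest.mock.Mock(), config=train_stepper_config)
```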
Continuation of the separation of training-specific concerns from inference stepper configs, building on #809.
This PR makes backwards-incompatible changes affecting `fme.coupled` training and fine-tuning configs:

- `parameter_init: ParameterInitializationConfig` is now configured using its respective `ComponentTrainingConfig` on the `CoupledTrainStepperConfig`.
- `parameter_init: CoupledParameterInitConfig` has been moved out of `CoupledStepperConfig` and into `CoupledTrainStepperConfig`.

Existing `fme.ace` training YAML configs will continue to work without changes, with `parameter_init` now transferred to `TrainStepperConfig` via `StepperConfig.get_train_stepper_config()`. This backwards compatibility will be removed in a future PR.

Changes:

- `CoupledTrainStepperConfig` now owns both `CoupledParameterInitConfig` and the per-component `ParameterInitializationConfig` (via `ComponentTrainingConfig`).
- `StepperConfig.get_stepper()` signature changed to accept an optional `ParameterInitializer` instead of a boolean flag.
- `TrainStepperConfig` now owns `parameter_init` and builds both the initializer and the underlying `Stepper` (see the sketch below).
- Tests added.
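For orientation, a hedged sketch of that ownership. The `get_train_stepper` signature, the keyword passed to `get_stepper`, and the `build()` helper on `parameter_init` are assumptions for illustration, not the actual fme API:

```python
import dataclasses


@dataclasses.dataclass
class TrainStepperConfig:
    parameter_init: "ParameterInitializationConfig"
    loss: "StepLossConfig"

    def get_train_stepper(self, stepper_config: "StepperConfig") -> "TrainStepper":
        # The training config builds the parameter initializer...
        initializer = self.parameter_init.build()  # placeholder builder name
        # ...hands it to get_stepper, which now takes an optional
        # ParameterInitializer instead of a boolean flag...
        stepper = stepper_config.get_stepper(parameter_initializer=initializer)
        # ...and then wraps the built Stepper for training.
        return TrainStepper(stepper=stepper, config=self)
```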