  ) -> StepLossABC:
      if self.n_steps == 0 or self.weight == 0.0:
-         return NullLossContributions()
+         return NullLossContributions(loss_obj)
This preserves the existing behavior where we used a component stepper's effective_loss_scaling to compute mse_fractional_components metrics even if the stepper had no loss contribution in coupled training.
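As a rough illustration of the behavior described here, the sketch below shows a null contribution that adds nothing to the loss but still forwards effective_loss_scaling from the loss object it wraps. ScaledLoss and NullLossSketch are hypothetical stand-ins written for this note, not the classes in this PR.

```python
class ScaledLoss:
    """Hypothetical loss object that exposes per-variable loss scaling."""

    def __init__(self, scaling):
        self._scaling = scaling

    @property
    def effective_loss_scaling(self):
        return self._scaling

    def __call__(self, prediction, target):
        # Plain sum of squared errors, only so the object is callable.
        return sum((p - t) ** 2 for p, t in zip(prediction, target))


class NullLossSketch:
    """Hypothetical null contribution: zero loss, but scaling stays available."""

    def __init__(self, loss_obj):
        self._loss = loss_obj

    @property
    def effective_loss_scaling(self):
        # Forwarded so metrics that need the scaling keep working.
        return self._loss.effective_loss_scaling

    def __call__(self, prediction, target):
        # No contribution to the training loss.
        return 0.0


null_loss = NullLossSketch(ScaledLoss({"sst": 2.0}))
```

With this shape, a caller can read null_loss.effective_loss_scaling for metrics even though null_loss(...) always returns zero.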
mcgibbon left a comment
Just some nits (nits are optional), I don't need to re-review them. LGTM
fme/coupled/test_loss.py
Outdated
@property
def effective_loss_scaling(self):
    raise NotImplementedError
- raise NotImplementedError
+ raise NotImplementedError()
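For context, the two spellings behave the same at runtime, since raising an exception class instantiates it with no arguments; the suggestion is a style preference. A small check, using helper names made up for this note:

```python
def raise_class():
    # Raising the class: Python instantiates it with no arguments.
    raise NotImplementedError


def raise_instance():
    # Raising an explicitly constructed instance.
    raise NotImplementedError()


def caught(fn):
    """Call fn and return the type and args of the exception it raises."""
    try:
        fn()
    except NotImplementedError as err:
        return type(err), err.args
    return None
```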
fme/coupled/test_loss.py
Outdated
  atmos_loss_config = LossContributionsConfig()
  atmosphere_loss = atmos_loss_config.build(
-     loss_obj=lambda *_, **__: torch.tensor(5.25),
+     loss_obj=Mock(spec=StepLoss, side_effect=lambda *_, **__: torch.tensor(5.25)),
nit: The three lines changed use three different ways to specify the loss side-effect - via mae_loss, via a lambda function returning a constant, and via a return_value instead of a side_effect. You could consider using return_value for this one to reduce that down to 2 ways, at least.
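To make the comparison concrete, here is a minimal sketch of the two Mock spellings for a constant loss. It uses a plain float instead of torch.tensor to stay standard-library only, and the StepLoss class below is a hypothetical stand-in for the real one.

```python
from unittest.mock import Mock


class StepLoss:
    """Hypothetical stand-in for the real StepLoss; only its type matters here."""

    def __call__(self, prediction, target):
        raise NotImplementedError()


# side_effect: the callable runs on every call and its result is returned.
via_side_effect = Mock(spec=StepLoss, side_effect=lambda *_, **__: 5.25)

# return_value: the same constant, without wrapping it in a lambda.
via_return_value = Mock(spec=StepLoss, return_value=5.25)

loss_a = via_side_effect("prediction", "target")
loss_b = via_return_value("prediction", "target")
```

For a constant result, return_value says the same thing with less machinery; side_effect is mainly useful when the result should depend on the arguments or when the call should raise.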
      n_samples=3,
  )
- output = coupler.train_on_batch(
+ train_stepper_config = CoupledTrainStepperConfig(
nit: Avoid the 3x copy-paste of this process by making a get_train_stepper_and_batch helper that does it and calls get_stepper_and_batch internally.
Good idea, but I'll defer this cleanup to #814 since the way in which the train stepper is built is going to change.
@dataclasses.dataclass
class CoupledTrainStepperConfig:
Do you have an example of the updated training config committed somewhere I could check out? It would be nice to have a baseline config for coupled training; with one, I could see the changes to the baseline in this PR.
Update: Ah I see test_train.py mostly fits this purpose, good. Still, could be nice to have a baseline in the future.
Agreed, I will work on a new PR to add the baseline.
fme/coupled/test_train.py
Outdated
loss_contributions:
n_steps: {loss_atmos_n_steps}
stepper:
loss:
Question: Why is loss: type: MSE in both the atmosphere: stepper: and in the train_stepper: atmosphere:? I am guessing it's because we haven't updated the ACE configs yet and it's required in the config, in which case that's fine, but I thought I should ask to be sure.
That's right, although including it in the yaml here isn't strictly necessary since there is a default value on StepperConfig. I'll remove it so it's a bit clearer here.
fme/coupled/train/train.py
Outdated
  atmosphere_normalize=stepper.atmosphere.normalizer.normalize,
- ocean_loss_scaling=stepper.ocean.effective_loss_scaling,
- atmosphere_loss_scaling=stepper.atmosphere.effective_loss_scaling,
+ ocean_loss_scaling=stepper.effective_loss_scaling.ocean,
nit: pass loss_scaling: stepper.loss_scaling instead of two arguments containing the parts
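A rough sketch of what this nit amounts to, assuming the scaling is already available as a single object with ocean and atmosphere parts. CoupledLossScaling and both functions below are hypothetical names for illustration, not the actual signatures in fme.

```python
import dataclasses


@dataclasses.dataclass
class CoupledLossScaling:
    """Hypothetical container holding both components' loss scaling."""

    ocean: dict
    atmosphere: dict


def configure_separate(ocean_loss_scaling, atmosphere_loss_scaling):
    # Current shape: the caller unpacks the parts into two arguments.
    return {"ocean": ocean_loss_scaling, "atmosphere": atmosphere_loss_scaling}


def configure_bundled(loss_scaling):
    # Suggested shape: one argument, unpacked where it is used.
    return {"ocean": loss_scaling.ocean, "atmosphere": loss_scaling.atmosphere}


scaling = CoupledLossScaling(ocean={"sst": 2.0}, atmosphere={"t2m": 0.5})
separate = configure_separate(scaling.ocean, scaling.atmosphere)
bundled = configure_bundled(scaling)
```

Both calls carry the same information; the bundled form keeps the call site shorter and removes the chance of swapping the two arguments.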
Adds train_stepper: CoupledTrainStepperConfig to the coupled training config, which configures and builds a CoupledTrainStepper implementing TrainStepperABC.
WARNING: This is a breaking change for existing coupled training configs.
Changes:
- Component stepper loss: StepLossConfig and loss_contributions: LossContributionsConfig are now configured via the ocean: ComponentTrainingConfig and atmosphere: ComponentTrainingConfig attributes of CoupledTrainStepperConfig.
- CoupledStepper no longer implements TrainStepperABC.
- Removed public loss_obj and effective_loss_scaling properties from fme.ace.stepper.Stepper and added a new public method build_loss.
- Tests added