Sophiex/dev/pretrained frozen teacher#1824
Conversation
The fix now:
1. FrozenTeacher inspects the teacher model's actual latent_heads attribute to determine what postprocessing is needed.
2. JEPA/DINO/iBOT postprocessing is set up based on which heads exist, using the identity transform for all of them, with warnings for DINO/iBOT since full centering isn't supported for frozen teachers.
3. Tests are updated to use models with latent_heads attributes.
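A minimal sketch of points 1–2, assuming the teacher exposes its SSL heads through a dict-like latent_heads attribute; the function name and overall structure here are illustrative, not the actual FrozenTeacher code:

```python
import logging

logger = logging.getLogger(__name__)


def _select_postprocessing(teacher_model) -> dict:
    """Choose identity postprocessing for each SSL head found on the frozen teacher."""
    postprocessing = {}
    for head_name in getattr(teacher_model, "latent_heads", {}):
        if head_name in ("DINO", "iBOT"):
            # Full centering/sharpening isn't supported for a frozen teacher,
            # so warn and fall back to the identity transform.
            logger.warning(
                "%s head found on frozen teacher; using identity postprocessing.",
                head_name,
            )
        postprocessing[head_name] = lambda latents: latents  # identity for all heads
    return postprocessing
```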
Summary of Changes
Key insight from your feedback: the frozen teacher may have been pre-trained with any method (forecasting, MAE, etc.) and doesn't need to have SSL latent heads. We should:
1. Use the student's training config to know which SSL losses are needed
2. Add identity heads (LatentPredictionHeadIdentity) to the teacher if they don't exist
3. Use identity postprocessing (JEPATargetProcessing) for all SSL losses
Changes Made
src/weathergen/train/target_and_aux_ssl_teacher.py:
- Added an import for LatentPredictionHeadIdentity.
- Rewrote FrozenTeacher.__init__ to:
  - accept training_cfg (the student's config) to determine the required SSL heads,
  - call _get_required_ssl_heads() to extract the loss names from the config,
  - call _ensure_identity_heads() to add missing heads to the teacher model,
  - set up identity postprocessing for all SSL losses.
- Added _get_required_ssl_heads(): extracts the SSL loss names from the training config, defaulting to {"JEPA"} if none are found.
- Added _ensure_identity_heads(): adds a LatentPredictionHeadIdentity for any missing head (see the sketch below).
- Updated from_pretrained() to pass cf.training_config to the constructor.
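A hedged sketch of these two helpers. It assumes the teacher stores its heads in a dict-like latent_heads attribute and that the training config lists its losses by name; torch.nn.Identity stands in for the project's LatentPredictionHeadIdentity, and the config access is illustrative:

```python
import torch.nn as nn

SSL_HEAD_NAMES = {"JEPA", "DINO", "iBOT"}  # assumed set of supported SSL losses


def _get_required_ssl_heads(training_cfg) -> set[str]:
    """Extract the SSL loss names the student will train with."""
    configured = set(getattr(training_cfg, "losses", None) or [])
    required = configured & SSL_HEAD_NAMES
    return required or {"JEPA"}  # default when no SSL loss is configured


def _ensure_identity_heads(teacher_model, required: set[str]) -> None:
    """Attach an identity latent head for every required head the teacher lacks."""
    for name in required:
        if name not in teacher_model.latent_heads:
            # The real code uses LatentPredictionHeadIdentity; nn.Identity is a stand-in.
            teacher_model.latent_heads[name] = nn.Identity()
```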
tests/test_encoder_teacher.py:
- Added model_without_latent_heads fixture (simulates a forecasting-only teacher)
- Added 5 new tests (the first is sketched below):
- test_frozen_teacher_adds_identity_heads_when_missing
- test_frozen_teacher_uses_training_cfg_for_heads
- test_frozen_teacher_defaults_to_jepa_without_config
- test_frozen_teacher_preserves_existing_heads
- test_frozen_teacher_all_postprocessing_is_identity
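A possible shape for that first test; the fixture names beyond model_without_latent_heads, the FrozenTeacher call signature, and the teacher.model attribute access are assumptions based on the summary above, not the actual test code:

```python
# Hypothetical pytest sketch; fixtures are assumed to live in the test module
# or conftest.py.
from weathergen.train.target_and_aux_ssl_teacher import FrozenTeacher


def test_frozen_teacher_adds_identity_heads_when_missing(
    model_without_latent_heads, training_cfg_with_jepa
):
    # The fixture simulates a forecasting-only teacher with no SSL latent heads.
    teacher = FrozenTeacher(model_without_latent_heads, training_cfg_with_jepa)
    # The JEPA head required by the student's config should have been filled in
    # with an identity head on the wrapped (frozen) teacher model.
    assert "JEPA" in teacher.model.latent_heads
```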
clessig left a comment:
Some high-level comments. The PR should be split into the three independent contributions that are currently in it.
`from weathergen.model.norms import AdaLayerNorm, RMSNorm`
`class LayerScale(nn.Module):`
Can we put this into a separate PR? It's not related to the frozen teacher.
`return x * self.gamma`
`class StochasticDepth(nn.Module):`
Can we put this into a separate PR? It's not related to the frozen teacher.
`return self.lr`
`def _set_param_group_lrs(self, base_lr: float):`
Can we put this into a separate PR? It's not related to the frozen teacher.
`lr_multiplier = g.get("lr_multiplier", 1.0)`
`g["lr"] = base_lr * lr_multiplier`
`def _apply_lr_multipliers(self):`
Can we put this into a separate PR? It's not related to the frozen teacher.
`# Initialize collapse monitor for SSL training`
`collapse_config = self.training_cfg.get("collapse_monitoring", {})`
`self.collapse_monitor = CollapseMonitor(collapse_config, None)  # device set later in run()`
devices[9] is available here
`self.ema_model.update(self.cf.general.istep * batch_size_total, batch_size_total)`
`# Compute collapse monitoring metrics`
`if self.collapse_monitor.should_compute(self.cf.general.istep):`
Move the if statement into the function that you call. It's better encapsulation.
`if bidx % self.train_log_freq.metrics == 0:`
`self._log(TRAIN)`
`# Log collapse metrics`
`if self.collapse_monitor.should_log(self.cf.general.istep):`
Move the if statement into the function that you call. It's better encapsulation.
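One way to realize that encapsulation, sketched with assumed config keys and method names; compute() and log() absorb the should_compute()/should_log() checks so the trainer can call them unconditionally. This is illustrative, not the project's CollapseMonitor:

```python
class CollapseMonitor:
    """Sketch: the step-frequency checks live inside the monitor, not at call sites."""

    def __init__(self, config: dict, device=None):
        self.compute_freq = config.get("compute_freq", 100)  # assumed config keys
        self.log_freq = config.get("log_freq", 100)
        self.device = device
        self._latest: dict[str, float] = {}

    def compute(self, istep: int, preds, targets_and_auxs) -> None:
        # Early return replaces the `if should_compute(...)` at the call site.
        if istep % self.compute_freq != 0:
            return
        self._latest = {"pred_std": float(preds.float().std())}  # placeholder metric

    def log(self, istep: int, logger) -> None:
        # Early return replaces the `if should_log(...)` at the call site.
        if istep % self.log_freq != 0:
            return
        logger.info("collapse metrics at step %d: %s", istep, self._latest)
```

The trainer would then call collapse_monitor.compute(istep, preds, targets_and_auxs) and collapse_monitor.log(istep, logger) without wrapping them in if statements.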
`is_rank_zero = is_root()`
`# Handle CompositeOptimizer (Muon+AdamW) separately`
`if isinstance(self.optimizer, CompositeOptimizer):`
If we have this encapsulation, we should also encapsulate the part for native AdamW into a separate function.
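A sketch of that symmetric encapsulation; the dispatcher and _get_full_adamw_state_dict names are hypothetical, and CompositeOptimizer is assumed to be the project's Muon+AdamW wrapper (import omitted):

```python
class TrainerCheckpointMixin:
    """Illustrative only: these would be methods on the trainer."""

    def _get_optimizer_state_dict(self, is_rank_zero: bool) -> dict:
        # Dispatch keeps the composite and native-AdamW paths symmetric.
        if isinstance(self.optimizer, CompositeOptimizer):  # project's Muon+AdamW wrapper
            return self._get_full_composite_optimizer_state_dict(is_rank_zero)
        return self._get_full_adamw_state_dict(is_rank_zero)

    def _get_full_adamw_state_dict(self, is_rank_zero: bool) -> dict:
        # Only rank zero materializes the full state dict; other ranks return {}.
        return self.optimizer.state_dict() if is_rank_zero else {}
```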
`else:`
`return {}`
`def _get_full_composite_optimizer_state_dict(self, is_rank_zero: bool):`
This should go to optimizer.py.
`self.t_start = time.time()`
`def _compute_collapse_metrics(self, preds, targets_and_auxs) -> None:`
This should not be in trainer.py.
I think it would be best to have this as a separate class, and then we could also split up the function.
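A sketch of what pulling this out of trainer.py into its own class (and splitting the function) could look like; the metric names and formulas are placeholders, not the project's actual collapse metrics, and the real method also receives targets_and_auxs:

```python
import torch
import torch.nn.functional as F


class CollapseMetrics:
    """Illustrative split of the metric computation into small, testable pieces."""

    def compute(self, preds: torch.Tensor) -> dict[str, float]:
        return {**self._embedding_std(preds), **self._pairwise_cosine(preds)}

    def _embedding_std(self, preds: torch.Tensor) -> dict[str, float]:
        # Very low per-dimension std is a classic symptom of representation collapse.
        return {"embedding_std": preds.float().std(dim=0).mean().item()}

    def _pairwise_cosine(self, preds: torch.Tensor) -> dict[str, float]:
        # High mean cosine similarity between samples also indicates collapse.
        flat = F.normalize(preds.float().flatten(1), dim=-1)
        sim = flat @ flat.T
        mask = ~torch.eye(sim.shape[0], dtype=torch.bool, device=sim.device)
        return {"mean_pairwise_cos": sim[mask].mean().item()}
```

The trainer (or the collapse monitor) would then delegate to CollapseMetrics.compute() instead of owning the computation itself.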
Description
The goal is to train against a frozen, pre-trained teacher (e.g. a teacher pre-trained with MAE).
Issue Number
#1815
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- `./scripts/actions.sh lint`
- `./scripts/actions.sh unit-test`
- `./scripts/actions.sh integration-test`
- `launch-slurm.py --time 60`