Add stability sweep #744
base: master
Conversation
Pull request overview
Adds a new sweep config for stability experiments and refactors RankMe/aReQ metric collection so those metrics are only computed and propagated when enabled via args; a small attention-initialization tweak also disables flash attention when non-standard softmax variants are used.
Changes:
- Added `explorations/relu2max_stability_sweep.yaml` to run a structured stability sweep using named static groups.
- Refactored `train.py` to gate RankMe/aReQ vector collection, metric computation, and logging/export behind a single `compute_rankme` condition.
- Updated `variations/attention_variations.py` to force-disable flash attention when `softmax_variant_attn != 'softmax'`.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `variations/attention_variations.py` | Disables flash attention automatically when a non-softmax attention normalization is selected. |
| `train.py` | Avoids collecting/returning/logging RankMe/aReQ unless requested by `--log_rankme`/`--log_areq`. |
| `explorations/relu2max_stability_sweep.yaml` | Introduces a stability sweep configuration combining architectural named groups with a relu2max-vs-softmax comparison. |
```python
@torch.no_grad()
def estimate_loss(self):
    out = {'datasets':{}}
    compute_rankme = self.args.log_rankme or self.args.log_areq
```
Copilot AI · Feb 9, 2026:
The new compute_rankme flag gates both RankMe and aReQ computation/logging, so the name is a bit misleading (it reads like only RankMe). Consider renaming to something like compute_rankme_areq / compute_repr_metrics to make the intent clearer and avoid future misuse.
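A minimal illustration of the suggested rename; the `Namespace` stand-in for the repo's parsed CLI args is hypothetical:

```python
from argparse import Namespace

# Hypothetical stand-in for the repo's parsed CLI args.
args = Namespace(log_rankme=False, log_areq=True)

# The broader name makes explicit that the flag gates both RankMe and aReQ.
compute_repr_metrics = args.log_rankme or args.log_areq
assert compute_repr_metrics  # enabling aReQ alone still triggers collection
```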
```python
if not self.disable_flash_attention:
    print("flash attention removed due to softmax alternative")
    self.disable_flash_attention = True
```
Copilot AI · Feb 9, 2026:
This print() will trigger once per attention module instance (e.g., every layer), which can spam logs for deep models/sweeps. Consider using the repo’s logging mechanism (or warnings.warn) and/or ensuring the message is emitted only once per process/config.
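One possible shape for that fix, sketched with the standard `warnings` module; the helper and guard names here are hypothetical, not from the repo:

```python
import warnings

_FLASH_DISABLE_WARNED = False  # module-level: shared by every attention layer


def maybe_disable_flash(module) -> None:
    """Disable flash attention on `module`, warning only once per process."""
    global _FLASH_DISABLE_WARNED
    if not module.disable_flash_attention:
        if not _FLASH_DISABLE_WARNED:
            warnings.warn("flash attention removed due to softmax alternative")
            _FLASH_DISABLE_WARNED = True
        module.disable_flash_attention = True
```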
```diff
@@ -0,0 +1,66 @@
+# default.yaml using named group mechanisms from sample.yaml
```
Copilot AI · Feb 9, 2026:
The header comment still says “default.yaml…”, but this file is a dedicated relu2max stability sweep config. Updating the comment to match the file’s purpose will avoid confusion when browsing/duplicating sweep configs.
Suggested change:

```diff
-# default.yaml using named group mechanisms from sample.yaml
+# relu2max_stability_sweep.yaml using named group mechanisms from sample.yaml
```
This pull request introduces a new configuration sweep YAML for stability experiments and refactors how RankMe/aReQ metrics are computed and logged in `train.py`. The refactor ensures that these metrics are only computed and logged when explicitly requested, improving efficiency and clarity. Additionally, a minor update disables flash attention when a non-standard softmax variant is used.

Experiment configuration:
- Added `explorations/relu2max_stability_sweep.yaml`, providing a comprehensive sweep configuration using named groups for various model architecture options, such as normalization, position embeddings, and activation variants. This enables systematic experimentation with different model settings.

RankMe/aReQ metrics computation and logging:
- Refactored `train.py` to introduce a `compute_rankme` flag, ensuring that RankMe and aReQ metrics are only computed, stored, and logged when the `log_rankme` or `log_areq` arguments are set. This affects metric collection, output dictionaries, TensorBoard logging, and CSV export; a sketch of the gating pattern follows below. [1] [2] [3] [4] [5] [6] [7] [8] [9] [10] [11] [12]

Attention mechanism update:
- In `variations/attention_variations.py`, flash attention is now automatically disabled (with a printed message) when a non-standard softmax variant is selected, ensuring compatibility and clarity during model initialization.
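For illustration, here is a minimal, self-contained sketch of that gating pattern. The `rankme` body below uses the standard effective-rank formulation (exponential of the entropy of normalized singular values) and is an assumption rather than the repo's exact code; the aReQ branch is elided:

```python
import torch


def rankme(z: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # Effective rank: exp of the entropy of the normalized singular values.
    s = torch.linalg.svdvals(z.float())
    p = s / s.sum() + eps
    return torch.exp(-(p * p.log()).sum())


@torch.no_grad()
def estimate_metrics(args, hidden_states: torch.Tensor) -> dict:
    out = {}
    # Gate all representation metrics behind the CLI flags, so the default
    # training path never collects or stores the extra vectors.
    compute_rankme = args.log_rankme or args.log_areq
    if compute_rankme:
        if args.log_rankme:
            out["rankme"] = rankme(hidden_states).item()
        # aReQ would be computed and stored here, gated on args.log_areq.
    return out
```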