
Conversation

@klei22 (Collaborator) commented Feb 9, 2026

This pull request introduces a new configuration sweep YAML for stability experiments and refactors how RankMe/aReQ metrics are computed and logged in train.py. The refactor ensures these metrics are only computed and logged when explicitly requested, improving efficiency and clarity. Additionally, a minor update disables flash attention when a non-standard softmax variant is used.

Experiment configuration:

  • Added explorations/relu2max_stability_sweep.yaml, a comprehensive sweep configuration that uses named groups for various model architecture options, such as normalization, position embeddings, and activation variants. This enables systematic experimentation with different model settings; a sketch of the named-group idea follows.
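
For orientation only, here is a minimal sketch of what a named-group sweep can look like. All keys and group names below are illustrative assumptions rather than the actual contents of relu2max_stability_sweep.yaml, except softmax_variant_attn and its softmax/relu2max values, which appear elsewhere in this PR:

```yaml
# Hypothetical named-group sweep sketch -- the schema here is assumed,
# not the repo's actual format.
named_groups:
  norm_options:                      # normalization variants (assumed keys)
    - norm_variant: layernorm
    - norm_variant: rmsnorm
  pos_embedding_options:             # position embedding variants (assumed keys)
    - position_embedding: learned
    - position_embedding: rope
  activation_options:                # attention normalization comparison
    - softmax_variant_attn: softmax
    - softmax_variant_attn: relu2max
```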

Rankme/AREQ metrics computation and logging:

  • Refactored train.py to introduce a compute_rankme flag, ensuring that RankMe and aReQ metrics are only computed, stored, and logged when the log_rankme or log_areq arguments are set. This affects metric collection, output dictionaries, TensorBoard logging, and CSV export; a sketch of the gating pattern follows.
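
As a rough, self-contained illustration of the gating pattern: the rankme function below follows the published definition (Garrido et al., 2023) and is not necessarily the repo's implementation, and the flags are stand-ins for the real argparse arguments.

```python
import torch

def rankme(feats: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """RankMe: exp of the entropy of the normalized singular values of
    an (N, D) feature matrix. Generic definition, not the repo's code."""
    s = torch.linalg.svdvals(feats)          # singular values, descending
    p = s / (s.sum() + eps)                  # normalize to a distribution
    return torch.exp(-(p * torch.log(p + eps)).sum())

# Gating as in the PR: skip collection and computation entirely unless
# one of the logging flags was passed.
log_rankme, log_areq = True, False           # stand-ins for argparse flags
compute_rankme = log_rankme or log_areq
if compute_rankme:
    feats = torch.randn(512, 64)             # stand-in for collected hidden states
    print(f"rankme: {rankme(feats).item():.2f}")
```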

Attention mechanism update:

  • In variations/attention_variations.py, flash attention is now automatically disabled (with a printed message) when a non-standard softmax variant is selected, ensuring compatibility and clarity during model initialization; the check is sketched below with assumed surrounding context.
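
Putting the pieces from this PR together, the init-time check looks roughly like the following. The guard condition comes from the review summary (softmax_variant_attn != 'softmax') and the inner three lines from the diff hunk quoted further down; the surrounding __init__ scaffolding is assumed:

```python
# Assumed scaffolding inside the attention module's __init__:
if config.softmax_variant_attn != 'softmax':
    # Fused flash-attention kernels bake in the standard softmax, so a
    # non-standard variant forces the fallback attention path.
    if not self.disable_flash_attention:
        print("flash attention removed due to softmax alternative")
        self.disable_flash_attention = True
```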

Copilot AI left a comment

Pull request overview

Adds a new sweep config for stability experiments, refactors RankMe/aReQ metric collection so those metrics are only computed and propagated when enabled via args, and makes a small attention-initialization tweak that disables flash attention when using non-standard softmax variants.

Changes:

  • Added explorations/relu2max_stability_sweep.yaml to run a structured stability sweep using named static groups.
  • Refactored train.py to gate RankMe/aReQ vector collection, metric computation, and logging/export behind a single compute_rankme condition.
  • Updated variations/attention_variations.py to force-disable flash attention when softmax_variant_attn != 'softmax'.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| variations/attention_variations.py | Disables flash attention automatically when a non-softmax attention normalization is selected. |
| train.py | Avoids collecting/returning/logging RankMe/aReQ unless requested by --log_rankme/--log_areq. |
| explorations/relu2max_stability_sweep.yaml | Introduces a stability sweep configuration combining architectural named groups with a relu2max-vs-softmax comparison. |


```python
@torch.no_grad()
def estimate_loss(self):
    out = {'datasets': {}}
    compute_rankme = self.args.log_rankme or self.args.log_areq
```

Copilot AI Feb 9, 2026


The new compute_rankme flag gates both RankMe and aReQ computation/logging, so the name is a bit misleading (it reads like only RankMe). Consider renaming to something like compute_rankme_areq / compute_repr_metrics to make the intent clearer and avoid future misuse.

Comment on lines +1149 to +1151
```python
if not self.disable_flash_attention:
    print("flash attention removed due to softmax alternative")
    self.disable_flash_attention = True
```

Copilot AI Feb 9, 2026


This print() will trigger once per attention module instance (e.g., every layer), which can spam logs for deep models/sweeps. Consider using the repo’s logging mechanism (or warnings.warn) and/or ensuring the message is emitted only once per process/config.
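
For instance, Python's standard warnings module gives once-per-callsite behavior out of the box (a generic sketch, not the repo's logging mechanism):

```python
import warnings

# Under the default warning filter, a UserWarning from the same
# module/line is printed only once per process, so deep models no
# longer emit one message per attention layer.
warnings.warn(
    "flash attention disabled due to non-standard softmax variant",
    UserWarning,
    stacklevel=2,
)
```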

```diff
@@ -0,0 +1,66 @@
+# default.yaml using named group mechanisms from sample.yaml
```

Copilot AI Feb 9, 2026


The header comment still says “default.yaml…”, but this file is a dedicated relu2max stability sweep config. Updating the comment to match the file’s purpose will avoid confusion when browsing/duplicating sweep configs.

Suggested change:

```diff
-# default.yaml using named group mechanisms from sample.yaml
+# relu2max_stability_sweep.yaml using named group mechanisms from sample.yaml
```
