@Luodian Luodian commented Jan 10, 2026

Summary

  • Implement proper weight decay filtering following MAE/DeiT/OpenCLIP best practices
  • Exclude 1D parameters and special tokens from weight decay

Problem

The current training applies weight_decay=0.05 to ALL parameters, including:

  • 1D parameters: LayerNorm weights, RMSNorm weights, all biases
  • Special tokens: learned probe, pos_embed, cls_token, etc.

This is suboptimal because:

  • LayerNorm/RMSNorm scale parameters control feature scaling; decaying them reduces model capacity
  • Biases don't benefit from weight decay (no overfitting risk from bias terms)
  • Learned tokens like probe (shape 1,1,C) shouldn't be shrunk toward zero

Solution

Add a build_adamw_param_groups() helper (sketched after this list) that applies three exclusion rules:

  1. Shape-based rule: Exclude param.ndim < 2 (catches all 1D params)
  2. Name-based rule: Exclude .bias suffix
  3. Token-based rule: Exclude probe, pos_embed, cls_token, mask_token, query_tokens, latents
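
A minimal sketch of these rules; the function signature, the NO_DECAY_TOKENS name, and the exact token list are illustrative assumptions, not necessarily the PR's verbatim implementation:

```python
import torch

# Name fragments treated as special learned tokens (assumed set,
# mirroring the token-based rule above).
NO_DECAY_TOKENS = ("probe", "pos_embed", "cls_token", "mask_token",
                   "query_tokens", "latents")


def build_adamw_param_groups(model: torch.nn.Module, weight_decay: float = 0.05):
    """Split trainable parameters into decay / no-decay groups for AdamW."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if (
            param.ndim < 2                                   # shape-based rule
            or name.endswith(".bias")                        # name-based rule
            or any(tok in name for tok in NO_DECAY_TOKENS)   # token-based rule
        ):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

Because the decay value is set per group, torch.optim.AdamW's own weight_decay argument is no longer the source of truth; each group carries its own setting.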

Evidence from Major Repos

Repository   Pattern
MAE          p.ndim == 1 + no_weight_decay_list
DeiT         timm-style no_weight_decay()
OpenCLIP     Two-group AdamW, WD=0 for p.ndim < 2
SigLIP       WD only on kernel (weight matrices)

Files Changed

  • training/train.py - Add the helper function and modify the optimizer setup (see the before/after sketch below)
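
A hypothetical before/after of the optimizer setup, assuming the helper sketched above and an lr variable already in scope; the exact call site in train.py may differ:

```python
# Before (hypothetical): one global decay value applied to every parameter.
# optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)

# After (hypothetical): per-group decay; 1D params, biases, and special
# tokens receive weight_decay=0.0 via build_adamw_param_groups().
param_groups = build_adamw_param_groups(model, weight_decay=0.05)
optimizer = torch.optim.AdamW(param_groups, lr=lr)
```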

Edge Cases Handled

  • probe in pooling head: shape (1, 1, C) with ndim=3, explicitly excluded by name (see the sanity check after this list)
  • Conv2d bias: patch embedding uses bias=False, no issue
  • MultiheadAttention biases: 1D, caught by ndim < 2 rule
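
An optional sanity check, reusing the assumed rules and NO_DECAY_TOKENS from the sketch above (model is assumed to be in scope), to confirm that norm weights, biases, and the pooling-head probe all land in the no-decay group:

```python
# List parameters that will receive weight_decay=0.0 under the three rules.
no_decay_names = [
    name for name, p in model.named_parameters()
    if p.requires_grad and (
        p.ndim < 2
        or name.endswith(".bias")
        or any(tok in name for tok in NO_DECAY_TOKENS)
    )
]
print(f"{len(no_decay_names)} parameters excluded from weight decay")
assert any("probe" in n for n in no_decay_names), "probe should be excluded"
```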

Problem:
- Weight decay was applied to ALL parameters including:
  - 1D parameters (LayerNorm/RMSNorm weights, biases)
  - Special learned tokens (probe, pos_embed, cls_token)
- This is suboptimal as these parameters don't benefit from shrinkage

Solution:
- Add build_adamw_param_groups() function following MAE/DeiT/OpenCLIP patterns
- Exclude from weight decay:
  - Parameters with ndim < 2 (catches all 1D params like norm weights, biases)
  - Parameters ending in '.bias'
  - Special tokens: probe, pos_embed, cls_token, mask_token, query_tokens, latents
- Apply per-group weight decay instead of global optimizer-level decay

References:
- MAE: facebookresearch/mae/util/lr_decay.py
- OpenCLIP: excludes 'p.ndim < 2' and 'bn/ln/bias' from decay
- DeiT: timm-style no_weight_decay() returning {pos_embed, cls_token}