From 5495dd116b718723ee1b87ce6fb3d620200d98e4 Mon Sep 17 00:00:00 2001
From: Bo Li
Date: Sat, 10 Jan 2026 16:01:27 +0800
Subject: [PATCH] fix: filter weight decay for LayerNorm, biases, and special
 tokens

Problem:
- Weight decay was applied to ALL parameters, including:
  - 1D parameters (LayerNorm/RMSNorm weights, biases)
  - Special learned tokens (probe, pos_embed, cls_token)
- This is suboptimal: these parameters do not benefit from shrinkage.

Solution:
- Add a build_adamw_param_groups() helper following MAE/DeiT/OpenCLIP patterns
- Exclude from weight decay:
  - Parameters with ndim < 2 (catches all 1D params: norm weights, biases)
  - Parameters whose name ends in '.bias'
  - Special tokens: probe, pos_embed, cls_token, mask_token, query_tokens, latents
- Apply per-group weight decay instead of a global optimizer-level decay

References:
- MAE: facebookresearch/mae/util/lr_decay.py
- OpenCLIP: excludes 'p.ndim < 2' and 'bn/ln/bias' from decay
- DeiT: timm-style no_weight_decay() returning {pos_embed, cls_token}
---
 training/train.py | 62 ++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 56 insertions(+), 6 deletions(-)

diff --git a/training/train.py b/training/train.py
index d3139f7..8c5d86f 100644
--- a/training/train.py
+++ b/training/train.py
@@ -24,6 +24,53 @@ torch._dynamo.config.optimize_ddp = False
 
+
+def build_adamw_param_groups(model, weight_decay: float, no_decay_names: set | None = None):
+    """
+    Build parameter groups with proper weight decay filtering.
+
+    Following best practices from MAE, DeiT, and OpenCLIP:
+    - No weight decay for 1D parameters (biases, LayerNorm/RMSNorm weights)
+    - No weight decay for special learned tokens (probe, cls_token, pos_embed, etc.)
+
+    Args:
+        model: PyTorch model
+        weight_decay: Weight decay value for decayed parameters
+        no_decay_names: Set of parameter name substrings to exclude from decay
+
+    Returns:
+        List of param group dicts for the optimizer
+    """
+    if no_decay_names is None:
+        no_decay_names = set()
+
+    decay_params = []
+    no_decay_params = []
+
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        # Exclude a parameter from weight decay if it is:
+        # 1. a 1D parameter or a bias (shape/name-based rule for norm weights, biases)
+        # 2. a special token/embedding (name-based rule)
+        should_no_decay = (
+            param.ndim < 2
+            or name.endswith(".bias")
+            or any(nd_name in name for nd_name in no_decay_names)
+        )
+
+        if should_no_decay:
+            no_decay_params.append(param)
+        else:
+            decay_params.append(param)
+
+    return [
+        {"params": no_decay_params, "weight_decay": 0.0},
+        {"params": decay_params, "weight_decay": weight_decay},
+    ]
+
+
 parser = argparse.ArgumentParser(description="Multi-dataset video training")
 
 # General
@@ -291,13 +338,15 @@ def _expand(name, v):
 
 if args.finetune_backbone:
     backbone.requires_grad_(True)
-    backbone_parameters = filter(lambda p: p.requires_grad, backbone.parameters())
+    # Build backbone parameter groups with proper weight decay filtering.
+    # Exclude: 1D params (biases, norm weights) and special tokens (probe, pos_embed, cls_token, ...).
+    no_decay_names = {"probe", "pos_embed", "cls_token", "mask_token", "query_tokens", "latents"}
+    backbone_param_groups = build_adamw_param_groups(backbone, args.weight_decay, no_decay_names)
 
 dict_pfc_modules = {}
 list_module_pfc = []
-parameters: list[dict] = [
-    {"params": backbone_parameters},
-]
+# Start with the backbone param groups (decay/no_decay split is already applied).
+parameters: list[dict] = backbone_param_groups.copy()
 
 for head_id, _ in enumerate(range(args.num_heads)):
     head_name = args.list_head_names[head_id]
@@ -348,8 +397,9 @@ def _expand(name, v):
 
 if args.opt == "adamw":
     optimizer_cls = torch.optim.AdamW
-
-    opt = optimizer_cls(parameters, lr=args.lr, weight_decay=args.weight_decay)
+    # weight_decay is already set per-group in backbone_param_groups and the PFC groups;
+    # pass weight_decay=0.0 at the optimizer level so the global default is not applied
+    # to any group that omits an explicit value.
+    opt = optimizer_cls(parameters, lr=args.lr, weight_decay=0.0)
     lr_scheduler = PolynomialLRWarmup(opt, int(args.total_steps * args.warmup_ratio), args.total_steps, 2)
 else:
     raise ValueError(f"{args.opt} not support!")
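
Reviewer note (not part of the patch): below is a minimal sanity check for the new grouping, assuming build_adamw_param_groups from the first hunk is in scope (train.py runs argparse at import time, so copying the function is easier than importing it). TinyBackbone, its sizes, and the decay value are made up purely for illustration; the point is that norm weights, biases, and the name-matched tokens land in the weight_decay=0.0 group while matrix weights keep the configured decay.

# Illustrative sanity check -- TinyBackbone is a hypothetical stand-in for the
# real backbone; build_adamw_param_groups is the helper added in the patch above.
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 32))  # excluded by name
        self.pos_embed = nn.Parameter(torch.zeros(1, 8, 32))  # excluded by name
        self.proj = nn.Linear(32, 32)   # weight decays, bias does not
        self.norm = nn.LayerNorm(32)    # 1D weight/bias -> no decay


model = TinyBackbone()
groups = build_adamw_param_groups(
    model,
    weight_decay=0.05,
    no_decay_names={"probe", "pos_embed", "cls_token", "mask_token", "query_tokens", "latents"},
)

for g in groups:
    n_tensors = len(g["params"])
    n_elements = sum(p.numel() for p in g["params"])
    print(f"weight_decay={g['weight_decay']}: {n_tensors} tensors, {n_elements} elements")
# Expected: only proj.weight ends up in the decayed group; cls_token, pos_embed,
# proj.bias, norm.weight, and norm.bias all land in the weight_decay=0.0 group.

# Per-group decay is authoritative, so the optimizer-level value stays at 0.0,
# mirroring the change in the last hunk of the patch.
opt = torch.optim.AdamW(groups, lr=1e-4, weight_decay=0.0)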