diff --git a/training/train.py b/training/train.py
index d3139f7..8c5d86f 100644
--- a/training/train.py
+++ b/training/train.py
@@ -24,6 +24,53 @@
 torch._dynamo.config.optimize_ddp = False
 
+
+def build_adamw_param_groups(model, weight_decay: float, no_decay_names: set | None = None):
+    """
+    Build parameter groups with proper weight decay filtering.
+
+    Following best practices from MAE, DeiT, OpenCLIP:
+    - No weight decay for 1D parameters (biases, LayerNorm/RMSNorm weights)
+    - No weight decay for special learned tokens (probe, cls_token, pos_embed, etc.)
+
+    Args:
+        model: PyTorch model
+        weight_decay: Weight decay value for decayed parameters
+        no_decay_names: Optional set of parameter name substrings to exclude from decay
+
+    Returns:
+        List of param group dicts for the optimizer
+    """
+    if no_decay_names is None:
+        no_decay_names = set()
+
+    decay_params = []
+    no_decay_params = []
+
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        # Check if the parameter should be excluded from weight decay:
+        # 1. 1D parameters (biases, norm weights) - shape-based rule
+        # 2. Special tokens/embeddings - name-based rule
+        should_no_decay = (
+            param.ndim < 2 or
+            name.endswith(".bias") or
+            any(nd_name in name for nd_name in no_decay_names)
+        )
+
+        if should_no_decay:
+            no_decay_params.append(param)
+        else:
+            decay_params.append(param)
+
+    return [
+        {"params": no_decay_params, "weight_decay": 0.0},
+        {"params": decay_params, "weight_decay": weight_decay},
+    ]
+
+
 parser = argparse.ArgumentParser(description="Multi-dataset video training")
 
 # General
@@ -291,13 +338,15 @@ def _expand(name, v):
 
 if args.finetune_backbone:
     backbone.requires_grad_(True)
-backbone_parameters = filter(lambda p: p.requires_grad, backbone.parameters())
+# Build backbone parameter groups with proper weight decay filtering.
+# Exclude: 1D params (biases, norm weights), special tokens (probe, pos_embed, cls_token).
+no_decay_names = {"probe", "pos_embed", "cls_token", "mask_token", "query_tokens", "latents"}
+backbone_param_groups = build_adamw_param_groups(backbone, args.weight_decay, no_decay_names)
 
 dict_pfc_modules = {}
 list_module_pfc = []
-parameters: list[dict] = [
-    {"params": backbone_parameters},
-]
+# Start with the backbone param groups (decay/no_decay split already applied).
+parameters: list[dict] = backbone_param_groups.copy()
 
 for head_id, _ in enumerate(range(args.num_heads)):
     head_name = args.list_head_names[head_id]
@@ -348,8 +397,9 @@
 
 if args.opt == "adamw":
     optimizer_cls = torch.optim.AdamW
-
-    opt = optimizer_cls(parameters, lr=args.lr, weight_decay=args.weight_decay)
+    # Note: weight_decay is already set per-group in backbone_param_groups and the PFC groups.
+    # Pass weight_decay=0.0 here; per-group values override it, so it only affects groups without one.
+    opt = optimizer_cls(parameters, lr=args.lr, weight_decay=0.0)
     lr_scheduler = PolynomialLRWarmup(opt, int(args.total_steps * args.warmup_ratio), args.total_steps, 2)
 else:
     raise ValueError(f"{args.opt} not support!")
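
For reviewers, a minimal sketch (not part of the patch) of how the returned groups behave with `torch.optim.AdamW`: per-group `weight_decay` entries override the constructor-level default, so the no-decay group stays at 0.0 and only the ≥2D weights receive decay. `TinyBackbone` and the 0.05 decay value are made-up stand-ins for the real backbone and `args.weight_decay`, and the snippet assumes `build_adamw_param_groups` from this patch is in scope (e.g., pasted alongside it).

```python
# Illustrative sketch only: TinyBackbone and weight_decay=0.05 are stand-ins,
# and build_adamw_param_groups from the patch is assumed to be defined in scope.
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 32))   # excluded by name rule
        self.pos_embed = nn.Parameter(torch.zeros(1, 16, 32))  # excluded by name rule
        self.proj = nn.Linear(32, 32)                          # weight decays, bias does not
        self.norm = nn.LayerNorm(32)                           # 1D params, no decay


model = TinyBackbone()
groups = build_adamw_param_groups(model, weight_decay=0.05,
                                  no_decay_names={"cls_token", "pos_embed"})

# Per-group weight_decay overrides the constructor default (0.0 here),
# so only proj.weight ends up decayed.
opt = torch.optim.AdamW(groups, lr=1e-4, weight_decay=0.0)
for group in opt.param_groups:
    print(len(group["params"]), "params, weight_decay =", group["weight_decay"])
# Expected:
#   5 params, weight_decay = 0.0   (cls_token, pos_embed, proj.bias, norm.weight, norm.bias)
#   1 params, weight_decay = 0.05  (proj.weight)
```

This is also why the patch passes `weight_decay=0.0` to the optimizer constructor: that value only reaches groups that do not carry their own `weight_decay` entry.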