62 changes: 56 additions & 6 deletions training/train.py
@@ -24,6 +24,53 @@

torch._dynamo.config.optimize_ddp = False

+
+def build_adamw_param_groups(model, weight_decay: float, no_decay_names: set | None = None):
+    """
+    Build parameter groups with proper weight decay filtering.
+
+    Following best practices from MAE, DeiT, and OpenCLIP:
+    - No weight decay for 1D parameters (biases, LayerNorm/RMSNorm weights)
+    - No weight decay for special learned tokens (probe, cls_token, pos_embed, etc.)
+
+    Args:
+        model: PyTorch model
+        weight_decay: Weight decay value for decayed parameters
+        no_decay_names: Set of parameter name substrings to exclude from decay
+
+    Returns:
+        List of param group dicts for the optimizer
+    """
+    if no_decay_names is None:
+        no_decay_names = set()
+
+    decay_params = []
+    no_decay_params = []
+
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        # Exclude a parameter from weight decay if it matches either rule:
+        # 1. 1D parameters (biases, norm weights) - shape-based rule
+        # 2. Special tokens/embeddings - name-based rule
+        should_no_decay = (
+            param.ndim < 2 or
+            name.endswith(".bias") or
+            any(nd_name in name for nd_name in no_decay_names)
+        )
+
+        if should_no_decay:
+            no_decay_params.append(param)
+        else:
+            decay_params.append(param)
+
+    return [
+        {"params": no_decay_params, "weight_decay": 0.0},
+        {"params": decay_params, "weight_decay": weight_decay},
+    ]
+
+
parser = argparse.ArgumentParser(description="Multi-dataset video training")

# General
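As a quick sanity check of the grouping rules in `build_adamw_param_groups`, here is a minimal sketch; the `Toy` module and the 0.05 decay value are illustrative, not part of this PR:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    # Minimal stand-in for a backbone: a 2D weight with its bias, a LayerNorm,
    # and a 3D learned token that only the name-based rule can catch.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 8))

groups = build_adamw_param_groups(Toy(), weight_decay=0.05, no_decay_names={"cls_token"})

# proj.bias, norm.weight, norm.bias (all 1D) and cls_token (3D, matched by name)
# land in the no-decay group; proj.weight is the only decayed parameter.
assert len(groups[0]["params"]) == 4 and groups[0]["weight_decay"] == 0.0
assert len(groups[1]["params"]) == 1 and groups[1]["weight_decay"] == 0.05

opt = torch.optim.AdamW(groups, lr=1e-4, weight_decay=0.0)
```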
@@ -291,13 +338,15 @@ def _expand(name, v):
if args.finetune_backbone:
    backbone.requires_grad_(True)

-backbone_parameters = filter(lambda p: p.requires_grad, backbone.parameters())
+# Build backbone parameter groups with proper weight decay filtering.
+# Exclude 1D params (biases, norm weights) and special tokens (probe, pos_embed, cls_token).
+no_decay_names = {"probe", "pos_embed", "cls_token", "mask_token", "query_tokens", "latents"}
+backbone_param_groups = build_adamw_param_groups(backbone, args.weight_decay, no_decay_names)

dict_pfc_modules = {}
list_module_pfc = []
-parameters: list[dict] = [
-    {"params": backbone_parameters},
-]
+# Start with the backbone param groups (already split into decay and no-decay).
+parameters: list[dict] = backbone_param_groups.copy()

for head_id, _ in enumerate(range(args.num_heads)):
    head_name = args.list_head_names[head_id]
@@ -348,8 +397,9 @@ def _expand(name, v):

if args.opt == "adamw":
    optimizer_cls = torch.optim.AdamW
-
-    opt = optimizer_cls(parameters, lr=args.lr, weight_decay=args.weight_decay)
+    # weight_decay is already set per-group in backbone_param_groups and the PFC groups.
+    # Per-group values override the optimizer-level default, so pass weight_decay=0.0
+    # here; any group without an explicit value then gets no decay.
+    opt = optimizer_cls(parameters, lr=args.lr, weight_decay=0.0)
    lr_scheduler = PolynomialLRWarmup(opt, int(args.total_steps * args.warmup_ratio), args.total_steps, 2)
else:
    raise ValueError(f"{args.opt} not supported!")
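For reference, the override semantics the `weight_decay=0.0` comment relies on: PyTorch fills optimizer-level defaults only into groups that do not set the key themselves. A minimal sketch (the toy parameters are illustrative):

```python
import torch

p_decay = torch.nn.Parameter(torch.randn(4, 4))
p_plain = torch.nn.Parameter(torch.randn(4))

opt = torch.optim.AdamW(
    [
        {"params": [p_decay], "weight_decay": 0.05},  # explicit per-group value wins
        {"params": [p_plain]},                        # inherits the optimizer default
    ],
    lr=1e-4,
    weight_decay=0.0,
)

assert opt.param_groups[0]["weight_decay"] == 0.05
assert opt.param_groups[1]["weight_decay"] == 0.0
```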