Skip to content

Conversation

@Luodian
Copy link
Collaborator

@Luodian Luodian commented Jan 10, 2026

Summary

  • Fix incorrect polynomial decay formula in LR scheduler
  • Use existing closed-form calculation for correctness

Problem

The get_lr() method had a mathematically incorrect formula:

# WRONG: This is NOT polynomial decay
decay_factor = ((1.0 - (l - w) / (t - w)) / (1.0 - (l - 1 - w) / (t - w))) ** self.power
return [(1/2 * group["lr"]) * (1+decay_factor) for group in self.optimizer.param_groups]

Mathematical Proof

Let d(l) = decay_factor (the correct ratio between consecutive steps)

The code computes: g(l) = (1 + d(l)) / 2

For a decaying schedule where 0 < d(l) < 1:

  • g(l) - d(l) = (1 - d(l)) / 2 > 0
  • Therefore g(l) > d(l) always

Result: LR decays slower than the intended polynomial schedule.

Visual Impact

Step 10000: intended LR = 1e-4, actual LR = ~1.2e-4 (20% higher)
Step 50000: intended LR = 5e-5, actual LR = ~7e-5 (40% higher)

Solution

Use the existing _get_closed_form_lr() which correctly computes:

base_lr * (1 - (step - warmup) / (total - warmup)) ** power

Additional Benefits

  • Numerical stability: Computes directly from base_lrs instead of accumulating incremental updates
  • Simpler code: Removes 15 lines of incorrect logic

Files Changed

  • training/lr_scheduler.py - Replace get_lr() implementation

Problem:
The get_lr() method had an incorrect formula for polynomial decay:
  LR(l) = (1/2 * group['lr']) * (1 + decay_factor)

This is mathematically NOT equivalent to polynomial decay.

Mathematical proof:
- Let d(l) = decay_factor = f(l)/f(l-1) (the correct ratio)
- Current formula computes: g(l) = (1 + d(l)) / 2
- For decaying schedule, 0 < d(l) < 1
- Therefore g(l) > d(l) always
- Result: LR decays SLOWER than intended polynomial schedule

The correct polynomial decay should be:
  LR(l) = base_lr * (1 - progress)^power

Solution:
Use the existing _get_closed_form_lr() which correctly computes:
  base_lr * (1 - (step - warmup) / (total - warmup))^power

This also improves numerical stability by computing directly from base_lrs
rather than accumulating incremental updates over millions of steps.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants