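From 88e035ce6267a82afdf1929d1597c1424b06e33a Mon Sep 17 00:00:00 2001 From: Bo Li Date: Sat, 10 Jan 2026 16:02:49 +0800 Subject: [PATCH] fix: use closed-form LR calculation to fix polynomial decay formula bug Problem: The get_lr() method used an incorrect recursive update for polynomial decay: LR(l) = (1/2 * group['lr']) * (1 + decay_factor), where group['lr'] is the previous step's value LR(l-1). This is mathematically NOT equivalent to polynomial decay. Mathematical proof: - Let f(l) = (1 - (l - w) / (t - w))^power be the intended schedule (w = warmup_iters, t = total_iters), and let d(l) = decay_factor = f(l)/f(l-1) be the correct per-step ratio - The current formula instead multiplies the previous LR by g(l) = (1 + d(l)) / 2 - For a decaying schedule (power > 0, w < l <= t), 0 <= d(l) < 1, with d(l) = 0 only at the final step l = t - Therefore g(l) > d(l) always - Result: the LR decays SLOWER than the intended polynomial schedule The correct polynomial decay should be: LR(l) = base_lr * (1 - progress)^power Solution: Use the existing _get_closed_form_lr(), which correctly computes: base_lr * (1 - (step - warmup) / (total - warmup))^power This also improves numerical stability by computing directly from base_lrs rather than accumulating incremental updates over millions of steps. Behavior change: get_lr() no longer short-circuits when last_epoch == 0 or last_epoch > total_iters (previously it returned the current group['lr'] values unchanged); those steps are now also delegated to _get_closed_form_lr(). 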
--- training/lr_scheduler.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/training/lr_scheduler.py b/training/lr_scheduler.py index cbbac5a..9af43bc 100644 --- a/training/lr_scheduler.py +++ b/training/lr_scheduler.py @@ -27,22 +27,11 @@ def get_lr(self): UserWarning, ) - if self.last_epoch == 0 or self.last_epoch > self.total_iters: - return [group["lr"] for group in self.optimizer.param_groups] - - if self.last_epoch <= self.warmup_iters: - return [ - base_lr * self.last_epoch / self.warmup_iters - for base_lr in self.base_lrs - ] - else: - l = self.last_epoch - w = self.warmup_iters - t = self.total_iters - decay_factor = ( - (1.0 - (l - w) / (t - w)) / (1.0 - (l - 1 - w) / (t - w)) - ) ** self.power - return [(1/2 * group["lr"]) * (1+decay_factor) for group in self.optimizer.param_groups] + # Use closed-form calculation for correctness and numerical stability + # Previous implementation had a formula bug: (1/2 * lr) * (1 + decay_factor) + # is NOT equivalent to polynomial decay. It causes slower decay than intended. + # The closed-form directly computes: base_lr * (1 - progress)^power + return self._get_closed_form_lr() def _get_closed_form_lr(self): if self.last_epoch <= self.warmup_iters: