2 changes: 1 addition & 1 deletion chameleon/base/optim/__init__.py
@@ -21,4 +21,4 @@
OPTIMIZERS.register_module(name=k, force=True, module=globals()[k])


-__all__ += ['PolynomialLRWarmup', 'WrappedLRScheduler']
+__all__ += ['PolynomialLRWarmup', 'WrappedLRScheduler', 'MultiStepLRWarmUp']
98 changes: 61 additions & 37 deletions chameleon/base/optim/polynomial_lr_warmup.py
@@ -1,55 +1,79 @@
-import warnings
+from typing import List

-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler

 from ...registry import OPTIMIZERS


 @OPTIMIZERS.register_module()
-class PolynomialLRWarmup(_LRScheduler):
+class PolynomialLRWarmup(LRScheduler):
+    """
+    Scheduler with an initial linear warm-up followed by polynomial decay.
+
+    - For the first `warmup_iters` steps, LR increases linearly
+      from 0 -> base_lr.
+    - For steps `warmup_iters < step <= total_iters`, LR decays as
+      base_lr * (1 - (step - warmup_iters) / (total_iters - warmup_iters))^power.
+    - After `total_iters`, LR is held at the final decayed value.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        warmup_iters (int): Number of steps for linear warm-up; must be >= 0.
+        total_iters (int): Total number of steps for warm-up + decay; must be >= warmup_iters.
+        power (float): Exponent for polynomial decay. Default: 1.0 (linear).
+        last_epoch (int): The index of the last step. Default: -1 (start from step 0).
+    """

     def __init__(
         self,
-        optimizer,
-        warmup_iters,
-        total_iters=5,
-        power=1.0,
-        last_epoch=-1,
-        verbose=False
+        optimizer: Optimizer,
+        warmup_iters: int,
+        total_iters: int,
+        power: float = 1.0,
+        last_epoch: int = -1,
     ):
-        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
+        # input validation
+        if warmup_iters < 0:
+            raise ValueError(f"warmup_iters must be >= 0, got {warmup_iters}")
+        if total_iters < warmup_iters:
+            raise ValueError(
+                f"total_iters ({total_iters}) must be >= warmup_iters ({warmup_iters})")
+        if power < 0:
+            raise ValueError(f"power must be non-negative, got {power}")
+
+        self.warmup_iters = warmup_iters
         self.total_iters = total_iters
         self.power = power
-        self.warmup_iters = warmup_iters

-    def get_lr(self):
-        if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.", UserWarning)
+        super().__init__(optimizer, last_epoch)

-        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
-            return [group["lr"] for group in self.optimizer.param_groups]
+    def get_closed_form(self) -> List[float]:
+        """
+        Compute the learning rate for the current `last_epoch` in closed form.
+        """
+        # Clamp epoch to [0, total_iters]
+        epoch = min(max(self.last_epoch, 0), self.total_iters)

-        if self.last_epoch <= self.warmup_iters:
-            return [base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs]
-        else:
-            l = self.last_epoch
-            w = self.warmup_iters
-            t = self.total_iters
-            decay_factor = ((1.0 - (l - w) / (t - w)) /
-                            (1.0 - (l - 1 - w) / (t - w))) ** self.power
-            return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
-
-    def _get_closed_form_lr(self):
-
-        if self.last_epoch <= self.warmup_iters:
-            return [
-                base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs]
-        else:
-            return [
-                (
-                    base_lr * (1.0 - (min(self.total_iters, self.last_epoch) - self.warmup_iters) / (
-                        self.total_iters - self.warmup_iters)) ** self.power
-                )
-                for base_lr in self.base_lrs
-            ]
+        # 1) Warm-up phase
+        if epoch <= self.warmup_iters:
+            return [
+                base_lr *
+                (epoch / self.warmup_iters if self.warmup_iters > 0 else 1.0)
+                for base_lr in self.base_lrs
+            ]
+
+        # 2) Polynomial decay phase
+        decay_steps = epoch - self.warmup_iters
+        decay_total = self.total_iters - self.warmup_iters
+        factor = (1.0 - decay_steps / decay_total) ** self.power
+        return [base_lr * factor for base_lr in self.base_lrs]
+
+    def get_lr(self) -> List[float]:
+        """
+        Called by the base class on each `scheduler.step()`;
+        delegates to the closed-form computation above.
+        """
+        return self.get_closed_form()
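
For reviewers who want to sanity-check the new schedule, here is a minimal usage sketch. It assumes only the import path exposed by this PR (`chameleon.base.optim`); the `Linear` model and the base LR of 1.0 are placeholders chosen so the warmup/decay factors are easy to read.

import torch
from chameleon.base.optim import PolynomialLRWarmup

# Stand-in model and optimizer; a base LR of 1.0 makes the factors legible.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

# 5 warmup steps, 20 total steps, linear decay (power=1.0).
scheduler = PolynomialLRWarmup(optimizer, warmup_iters=5, total_iters=20, power=1.0)

for step in range(25):
    optimizer.step()
    scheduler.step()
    # Expected trajectory: 0.2, 0.4, ..., 1.0 over the first 5 steps,
    # then a linear ramp down to 0.0 at step 20, held constant afterwards.
    print(step, scheduler.get_last_lr())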
162 changes: 122 additions & 40 deletions chameleon/base/optim/warm_up.py
@@ -1,68 +1,150 @@
-from typing import List
+from typing import List, Optional

 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import LRScheduler, MultiStepLR

+# Project-specific registry
 from ...registry import OPTIMIZERS

-__all__ = ['WrappedLRScheduler']
+__all__ = ["WrappedLRScheduler", "MultiStepLRWarmUp"]


 @OPTIMIZERS.register_module()
-class WrappedLRScheduler(_LRScheduler):
+class WrappedLRScheduler(LRScheduler):
     """
-    Gradually warm-up(increasing) learning rate in optimizer.
-    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
+    Gradual warmup scheduler.
+
+    During the first `milestone` steps (or epochs), the learning rate
+    increases linearly from 0 (or base_lr) up to base_lr * multiplier.
+    After warmup completes, scheduling is delegated to `after_scheduler`.

     Args:
         optimizer (Optimizer): Wrapped optimizer.
-        milestone (int):
-            milestone step for warm-up.
-        multiplier (float):
-            A factor to multiply base_lr.
-            if multiplier > 1.0, learning rate = base lr * multiplier.
-            if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
-        after_scheduler (lr_scheduler):
-            after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
+        milestone (int): Number of steps (or epochs) for warmup; must be > 0.
+        multiplier (float, optional): Final LR = base_lr * multiplier.
+            - If multiplier == 1.0, warmup goes from 0 -> base_lr.
+            - If multiplier > 1.0, warmup goes from base_lr -> base_lr * multiplier.
+        after_scheduler (LRScheduler, optional): Scheduler to use after warmup.
+        last_epoch (int, optional): The index of the last epoch. Default: -1.
     """

     def __init__(
         self,
         optimizer: Optimizer,
         milestone: int,
         multiplier: float = 1.0,
-        after_scheduler: _LRScheduler = None,
-        interval='step'
+        after_scheduler: Optional[LRScheduler] = None,
+        last_epoch: int = -1
     ):
-        self.multiplier = multiplier
-        if self.multiplier < 1.:
-            raise ValueError('multiplier should be greater thant or equal to 1.')
+        if milestone <= 0:
+            raise ValueError("milestone must be > 0.")
+        if multiplier < 1.0:
+            raise ValueError("multiplier must be >= 1.0.")
+
         self.milestone = milestone
+        self.multiplier = multiplier
         self.after_scheduler = after_scheduler
         self.finished = False
-        self.interval = interval
-        super().__init__(optimizer)  # need be set in the end of __init__
+
+        # Initialize the base class last, with optimizer and last_epoch
+        super().__init__(optimizer, last_epoch)

     def get_lr(self):
-        # do after_scheduler
-        if self.last_epoch > self.milestone:
-            if self.after_scheduler:
-                if not self.finished:
-                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
-                    self.finished = True
-                return self.after_scheduler.get_last_lr()
-            return [base_lr * self.multiplier for base_lr in self.base_lrs]
-
-        if self.multiplier == 1.0:
-            return [base_lr * (float(self.last_epoch) / self.milestone) for base_lr in self.base_lrs]
-        else:
-            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.milestone + 1.) for base_lr in self.base_lrs]
+        # During warmup phase
+        if self.last_epoch <= self.milestone:
+            if self.multiplier == 1.0:
+                # Linear increase: 0 -> base_lr
+                return [
+                    base_lr * (self.last_epoch / self.milestone)
+                    for base_lr in self.base_lrs
+                ]
+            else:
+                # Linear increase: base_lr -> base_lr * multiplier
+                return [
+                    base_lr * ((self.multiplier - 1.0) *
+                               self.last_epoch / self.milestone + 1.0)
+                    for base_lr in self.base_lrs
+                ]
+
+        # After warmup completes
+        if self.after_scheduler is not None:
+            # On first transition, reset the after_scheduler's base_lrs
+            if not self.finished:
+                self.after_scheduler.base_lrs = [
+                    base_lr * self.multiplier for base_lr in self.base_lrs
+                ]
+                self.finished = True
+            # Delegate to after_scheduler
+            return self.after_scheduler.get_last_lr()
+
+        # No after_scheduler: keep LR at base_lr * multiplier
+        return [base_lr * self.multiplier for base_lr in self.base_lrs]

-    def step(self, epoch=None, metrics=None):
-        if self.finished and self.after_scheduler:
-            if epoch is None:
-                self.after_scheduler.step(None)
-            else:
-                self.after_scheduler.step(epoch - self.milestone)
-            self._last_lr = self.after_scheduler.get_last_lr()
-        else:
-            return super().step()
+    def step(self, epoch: Optional[int] = None, metrics: Optional[float] = None):
+        """
+        Update the learning rate.
+
+        If warmup is finished and an after_scheduler is provided,
+        delegate the step to after_scheduler. Otherwise, call the
+        base class step() to continue warmup.
+
+        Args:
+            epoch (int, optional): Current epoch or step index.
+            metrics (float, optional): Metric for ReduceLROnPlateau.
+        """
+        if self.finished and self.after_scheduler is not None:
+            # If using ReduceLROnPlateau (metric-based), pass metrics first
+            if metrics is not None and "plateau" in type(self.after_scheduler).__name__.lower():
+                self.after_scheduler.step(
+                    metrics, epoch - self.milestone if epoch is not None else None)
+            else:
+                # Standard scheduler.step(epoch)
+                self.after_scheduler.step(
+                    epoch - self.milestone if epoch is not None else None)
+            # Sync the last learning rates
+            self._last_lr = self.after_scheduler.get_last_lr()
+        else:
+            # Still in warmup or no after_scheduler: use base class logic
+            super().step(epoch)
+
+
+@OPTIMIZERS.register_module(is_model_builder=True)
+def MultiStepLRWarmUp(
+    optimizer: Optimizer,
+    milestones: List[int],
+    warmup_milestone: int,
+    gamma: float = 0.1,
+    last_epoch: int = -1
+) -> WrappedLRScheduler:
+    """
+    Factory function to create a warmup + MultiStepLR scheduler.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        milestones (List[int]): List of epoch indices where LR is decayed by gamma.
+        warmup_milestone (int): Number of epochs for linear warmup.
+        gamma (float, optional): Multiplicative LR decay factor for MultiStepLR. Default: 0.1.
+        last_epoch (int, optional): Index of the last epoch. Default: -1 (start from scratch).
+
+    Returns:
+        WrappedLRScheduler: Scheduler that linearly warms up for `warmup_milestone`
+        epochs, then delegates to MultiStepLR.
+    """
+    # 1) create the MultiStepLR scheduler that will run *after* warmup
+    multi_step = MultiStepLR(
+        optimizer=optimizer,
+        milestones=milestones,
+        gamma=gamma,
+        last_epoch=last_epoch,
+    )
+
+    # 2) wrap it with linear warmup
+    return WrappedLRScheduler(
+        optimizer=optimizer,
+        milestone=warmup_milestone,
+        multiplier=1.0,  # warmup from 0 -> base_lr
+        after_scheduler=multi_step,
+        last_epoch=last_epoch,
+    )
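
A matching sketch for the new `MultiStepLRWarmUp` factory, again assuming only the import path from this PR; the model, base LR, and milestone values are placeholders.

import torch
from chameleon.base.optim import MultiStepLRWarmUp

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Linear warmup over 5 epochs, then MultiStepLR decays the LR by 10x
# at each milestone.
scheduler = MultiStepLRWarmUp(
    optimizer,
    milestones=[30, 60],
    warmup_milestone=5,
    gamma=0.1,
)

for epoch in range(80):
    # ... run one training epoch here ...
    optimizer.step()
    scheduler.step()

One subtlety worth flagging in review: when `step()` is called without an explicit epoch, the wrapped `MultiStepLR` only starts advancing its own counter once warmup has finished, so the decay milestones are effectively measured from the end of warmup rather than from epoch 0.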