from bisect import bisect_right
from math import cos, pi

from torch.optim.lr_scheduler import _LRScheduler


class LRSchedulerWithWarmup(_LRScheduler):
    """Scheduler with a warmup phase followed by one of several decay
    schedules: "step", "exp", "poly", "cosine" or "linear".

    For the first ``warmup_epochs`` epochs the learning rate is scaled by a
    factor that either stays at ``warmup_factor`` or ramps linearly from
    ``warmup_factor`` to 1, depending on ``warmup_method``; afterwards the
    selected decay mode takes over until ``total_epochs``.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        mode="step",
        warmup_factor=1.0 / 3,
        warmup_epochs=10,
        warmup_method="linear",
        total_epochs=100,
        target_lr=0,
        power=0.9,
        last_epoch=-1,
    ):
        if list(milestones) != sorted(milestones):
            raise ValueError(
                "Milestones should be a list of"
                " increasing integers. Got {}".format(milestones)
            )
        if mode not in ("step", "exp", "poly", "cosine", "linear"):
            raise ValueError(
                "Only 'step', 'exp', 'poly', 'cosine' or 'linear' learning"
                " rate scheduler accepted, got {}".format(mode)
            )
        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted,"
                " got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.mode = mode
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_epochs = warmup_epochs
        self.warmup_method = warmup_method
        self.total_epochs = total_epochs
        self.target_lr = target_lr
        self.power = power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Warmup phase: scale each base LR by a factor that is either held
        # at warmup_factor or ramped linearly from warmup_factor up to 1.
        if self.last_epoch < self.warmup_epochs:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = self.last_epoch / self.warmup_epochs
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
            return [base_lr * warmup_factor for base_lr in self.base_lrs]
        # Step decay: multiply by gamma once for each milestone already passed.
        if self.mode == "step":
            return [
                base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs
            ]
        # Fraction of the post-warmup schedule completed, in [0, 1].
        epoch_ratio = (self.last_epoch - self.warmup_epochs) / (
            self.total_epochs - self.warmup_epochs
        )
        if self.mode == "exp":
            # Exponential decay: base_lr * power ** t for t in [0, 1].
            factor = epoch_ratio
            return [base_lr * self.power ** factor for base_lr in self.base_lrs]
        if self.mode == "linear":
            # Linear decay from base_lr down to zero.
            factor = 1 - epoch_ratio
            return [base_lr * factor for base_lr in self.base_lrs]
        if self.mode == "poly":
            # Polynomial decay from base_lr to target_lr with exponent `power`.
            factor = 1 - epoch_ratio
            return [
                self.target_lr + (base_lr - self.target_lr) * factor ** self.power
                for base_lr in self.base_lrs
            ]
        if self.mode == "cosine":
            # Cosine annealing from base_lr to target_lr.
            factor = 0.5 * (1 + cos(pi * epoch_ratio))
            return [
                self.target_lr + (base_lr - self.target_lr) * factor
                for base_lr in self.base_lrs
            ]
        raise NotImplementedError
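

# A minimal usage sketch. The toy model, optimizer, and schedule settings
# below are illustrative assumptions, not values prescribed by the scheduler.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(10, 2)  # placeholder model for demonstration
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = LRSchedulerWithWarmup(
        optimizer,
        milestones=[30, 60, 90],  # only consulted by the "step" mode
        mode="cosine",
        warmup_epochs=10,
        total_epochs=100,
    )
    for epoch in range(100):
        # ... one training epoch would run here ...
        optimizer.step()  # placeholder step; normally called once per batch
        scheduler.step()
        print(epoch, scheduler.get_last_lr())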