lr_scheduler.py

from bisect import bisect_right
from math import cos, pi

from torch.optim.lr_scheduler import _LRScheduler


class LRSchedulerWithWarmup(_LRScheduler):
    """Epoch-based LR scheduler with a warmup phase followed by one of
    several decay modes: "step", "exp", "linear", "poly" or "cosine"."""

    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        mode="step",
        warmup_factor=1.0 / 3,
        warmup_epochs=10,
        warmup_method="linear",
        total_epochs=100,
        target_lr=0,
        power=0.9,
        last_epoch=-1,
    ):
        if list(milestones) != sorted(milestones):
            raise ValueError(
                "Milestones should be a list of"
                " increasing integers. Got {}".format(milestones)
            )
        if mode not in ("step", "exp", "poly", "cosine", "linear"):
            raise ValueError(
                "Only 'step', 'exp', 'poly', 'cosine' or 'linear' learning rate "
                "scheduler accepted, got {}".format(mode)
            )
        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )

        self.milestones = milestones
        self.mode = mode
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_epochs = warmup_epochs
        self.warmup_method = warmup_method
        self.total_epochs = total_epochs
        self.target_lr = target_lr
        self.power = power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # During warmup, scale every base LR by a factor that is either
        # constant or ramps linearly from warmup_factor up to 1.
        if self.last_epoch < self.warmup_epochs:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = self.last_epoch / self.warmup_epochs
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
            return [base_lr * warmup_factor for base_lr in self.base_lrs]

        if self.mode == "step":
            # Multiply by gamma once for every milestone already passed.
            return [
                base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs
            ]

        # Fraction of the post-warmup schedule completed, in [0, 1].
        epoch_ratio = (self.last_epoch - self.warmup_epochs) / (
            self.total_epochs - self.warmup_epochs
        )

        if self.mode == "exp":
            # Exponential decay: the LR shrinks smoothly from base_lr
            # to base_lr * power over the remaining epochs.
            factor = epoch_ratio
            return [base_lr * self.power ** factor for base_lr in self.base_lrs]
        if self.mode == "linear":
            # Linear decay from base_lr down to 0.
            factor = 1 - epoch_ratio
            return [base_lr * factor for base_lr in self.base_lrs]
        if self.mode == "poly":
            # Polynomial (DeepLab-style) decay from base_lr to target_lr:
            # lr = target_lr + (base_lr - target_lr) * (1 - t) ** power.
            factor = 1 - epoch_ratio
            return [
                self.target_lr + (base_lr - self.target_lr) * factor ** self.power
                for base_lr in self.base_lrs
            ]
        if self.mode == "cosine":
            # Cosine annealing from base_lr down to target_lr.
            factor = 0.5 * (1 + cos(pi * epoch_ratio))
            return [
                self.target_lr + (base_lr - self.target_lr) * factor
                for base_lr in self.base_lrs
            ]
        raise NotImplementedError
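
# A minimal usage sketch: wires the scheduler to a throwaway SGD optimizer
# and prints the learning rate produced at each epoch. The model, the base
# LR and the schedule hyperparameters below are illustrative assumptions,
# not values prescribed by this module.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = LRSchedulerWithWarmup(
        optimizer,
        milestones=[30, 60],  # consulted only when mode="step"
        mode="cosine",
        warmup_epochs=5,
        total_epochs=20,
    )
    for epoch in range(20):
        # ... one epoch of training would run here ...
        print(epoch, scheduler.get_last_lr())
        scheduler.step()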