optimizer.py

# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
from torch import optim as optim


def build_optimizer(config, model):
    """Build optimizer, set weight decay of normalization to 0 by default."""
    parameters = set_weight_decay(model, {}, {})

    opt_name = config.optimizer.name
    optimizer = None
    if opt_name == 'adamw':
        optimizer = optim.AdamW(
            parameters,
            eps=config.optimizer.eps,
            betas=config.optimizer.betas,
            lr=config.base_lr,
            weight_decay=config.weight_decay)
    else:
        raise ValueError(f'Unsupported optimizer: {opt_name}')

    return optimizer
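

# Illustrative variant (not part of the original file): build_optimizer above
# passes empty skip sets, but set_weight_decay also accepts an explicit skip
# list and skip keywords. The sketch below assumes the model optionally exposes
# no_weight_decay() / no_weight_decay_keywords(), as timm-style backbones do;
# build_optimizer_with_skip is a hypothetical helper named here for illustration.
def build_optimizer_with_skip(config, model):
    skip, skip_keywords = {}, {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()
    parameters = set_weight_decay(model, skip, skip_keywords)
    return optim.AdamW(
        parameters,
        eps=config.optimizer.eps,
        betas=config.optimizer.betas,
        lr=config.base_lr,
        weight_decay=config.weight_decay)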


def set_weight_decay(model, skip_list=(), skip_keywords=()):
    """Split parameters into a weight-decay group and a no-weight-decay group."""
    has_decay = []
    no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        # 1-D tensors (norm weights), biases, and explicitly skipped parameters
        # get no weight decay.
        if len(param.shape) == 1 or name.endswith('.bias') or (name in skip_list) or \
                check_keywords_in_name(name, skip_keywords):
            no_decay.append(param)
            # print(f"{name} has no weight decay")
        else:
            has_decay.append(param)
    return [{'params': has_decay}, {'params': no_decay, 'weight_decay': 0.}]


def check_keywords_in_name(name, keywords=()):
    """Return True if any of the given keywords appears in the parameter name."""
    isin = False
    for keyword in keywords:
        if keyword in name:
            isin = True
    return isin
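

# Minimal usage sketch (not part of the original file): the config layout below
# (optimizer.name / eps / betas, base_lr, weight_decay) simply mirrors the
# attributes build_optimizer reads and is assumed here for illustration only.
if __name__ == '__main__':
    from types import SimpleNamespace

    from torch import nn

    config = SimpleNamespace(
        optimizer=SimpleNamespace(name='adamw', eps=1e-8, betas=(0.9, 0.999)),
        base_lr=5e-4,
        weight_decay=0.05)
    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    optimizer = build_optimizer(config, model)
    # The LayerNorm weight and all biases land in the zero-weight-decay group.
    print(optimizer)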