checkpoint.py

# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
import os
from collections import defaultdict

import torch
import torch.distributed as dist
from mmcv.runner import CheckpointLoader
from omegaconf import read_write

from .logger import get_logger

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None


def load_checkpoint(config, model, optimizer, lr_scheduler):
    """Load a checkpoint and restore model/optimizer/scheduler/AMP state.

    Returns the metrics dict stored in the checkpoint (empty if only the
    model weights were restored).
    """
    logger = get_logger()
    logger.info(f'==============> Resuming from {config.checkpoint.resume}....................')
    checkpoint = CheckpointLoader.load_checkpoint(config.checkpoint.resume, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    metrics = defaultdict(float)
    # Only restore training state when resuming a full run, not in eval-only mode.
    if (not config.evaluate.eval_only and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint
            and 'epoch' in checkpoint):
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        with read_write(config):
            config.train.start_epoch = checkpoint['epoch'] + 1
        if 'amp' in checkpoint and config.train.amp_opt_level != 'O0' and checkpoint[
                'config'].train.amp_opt_level != 'O0':
            amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.checkpoint.resume}' (epoch {checkpoint['epoch']})")
        metrics = checkpoint['metrics']

    del checkpoint
    torch.cuda.empty_cache()
    return metrics
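
# Usage sketch for load_checkpoint (hypothetical call site; assumes `config`,
# `model`, `optimizer` and `lr_scheduler` are already built and that
# `config.checkpoint.resume` points at a valid checkpoint file or URL):
#
#     if config.checkpoint.resume:
#         metrics = load_checkpoint(config, model, optimizer, lr_scheduler)
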
def save_checkpoint(config, epoch, model, metrics, optimizer, lr_scheduler, suffix=''):
    """Save model/optimizer/scheduler/AMP state plus metrics for the given epoch."""
    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'metrics': metrics,
        'epoch': epoch,
        'config': config
    }
    logger = get_logger()
    for k, v in metrics.items():
        save_state[k] = v

    if config.train.amp_opt_level != 'O0':
        save_state['amp'] = amp.state_dict()

    if len(suffix) > 0 and not suffix.startswith('_'):
        suffix = '_' + suffix
    filename = f'ckpt_epoch_{epoch}{suffix}.pth'

    save_path = os.path.join(config.output, filename)
    logger.info(f'{save_path} saving......')
    torch.save(save_state, save_path)
    # Also keep a rolling copy under a fixed name for auto-resume.
    torch.save(save_state, os.path.join(config.output, 'checkpoint.pth'))
    logger.info(f'{save_path} saved !!!')

    # Prune old per-epoch checkpoints, keeping at most `max_kept` of them.
    if config.checkpoint.max_kept > 0:
        if epoch >= config.checkpoint.max_kept:
            logger.info(f'Epoch: {epoch}, greater than config.checkpoint.max_kept: {config.checkpoint.max_kept}')
            end_clean_epoch = epoch - config.checkpoint.max_kept
            old_path_list = []
            for cur_clean_epoch in range(end_clean_epoch + 1):
                old_path = os.path.join(config.output, f'ckpt_epoch_{cur_clean_epoch}{suffix}.pth')
                if os.path.exists(old_path):
                    logger.info(f'old checkpoint path {old_path} exists')
                    old_path_list.append(old_path)

            for old_path in old_path_list[:-config.checkpoint.max_kept]:
                os.remove(old_path)
                logger.info(f'old checkpoint path {old_path} removed!!!')
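
# Usage sketch for save_checkpoint (hypothetical call site at the end of a
# training epoch; where exactly it is called from depends on the training script):
#
#     save_checkpoint(config, epoch, model, metrics, optimizer, lr_scheduler)
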
def get_grad_norm(parameters, norm_type=2):
    """Compute the total `norm_type`-norm of the gradients of `parameters`."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item()**norm_type
    total_norm = total_norm**(1. / norm_type)
    return total_norm
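
# Usage sketch for get_grad_norm (hypothetical training step; typically called
# after backward() and before optimizer.step() to log the gradient norm):
#
#     loss.backward()
#     grad_norm = get_grad_norm(model.parameters())
#     optimizer.step()
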
def auto_resume_helper(output_dir):
    """Return the checkpoint to resume from in `output_dir`, or None if there is none."""
    # Prefer the rolling 'checkpoint.pth' written by save_checkpoint.
    if os.path.exists(os.path.join(output_dir, 'checkpoint.pth')):
        return os.path.join(output_dir, 'checkpoint.pth')

    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    print(f'All checkpoints found in {output_dir}: {checkpoints}')
    if len(checkpoints) > 0:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        print(f'The latest checkpoint found: {latest_checkpoint}')
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file
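
# Usage sketch for auto_resume_helper (hypothetical; the auto-resume flag name
# depends on the config, shown here as `config.checkpoint.auto_resume`):
#
#     if config.checkpoint.auto_resume:
#         resume_file = auto_resume_helper(config.output)
#         if resume_file is not None:
#             with read_write(config):
#                 config.checkpoint.resume = resume_file
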
def reduce_tensor(tensor):
    """Average a tensor across all distributed ranks."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
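
# Usage sketch for reduce_tensor (only meaningful after torch.distributed has
# been initialised; averages a scalar metric across all ranks, e.g. for logging):
#
#     loss_avg = reduce_tensor(loss.detach())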