# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
# Modified by Jilan Xu
# -------------------------------------------------------------------------
import os
from collections import defaultdict

import torch
import torch.distributed as dist
from mmcv.runner import CheckpointLoader
from omegaconf import read_write

from .logger import get_logger

try:
    from apex import amp
except ImportError:
    amp = None
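
# Note: apex is optional. `amp` stays None when apex is not installed; the
# amp-specific branches below are only taken when config.train.amp_opt_level
# is something other than 'O0'.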


def load_checkpoint_stage1(config, model):
    logger = get_logger()
    logger.info(f'==============> Resuming stage1 checkpoint from {config.checkpoint.stage1_checkpoint}....................')
    checkpoint = CheckpointLoader.load_checkpoint(config.checkpoint.stage1_checkpoint, map_location='cpu')
    ### load online model parameters ###
    new_state_dict = {}
    # Parameters introduced after stage1 keep their fresh initialization.
    new_params = ['logit_scale_mask']
    for k, v in model.state_dict().items():
        if k in new_params:
            continue
        if k in checkpoint['model']:
            new_state_dict[k] = checkpoint['model'][k]
        else:
            # Initialize momentum-encoder weights from the corresponding
            # online-encoder weights in the stage1 checkpoint.
            oldk = k.replace('img_encoder_momentum', 'img_encoder')
            if oldk in checkpoint['model']:
                new_state_dict[k] = checkpoint['model'][oldk]
    msg = model.load_state_dict(new_state_dict, strict=False)
    logger.info(msg)
    del checkpoint
    torch.cuda.empty_cache()
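
# Usage sketch (illustrative): stage2 training would typically call this once
# after building the model, before creating the optimizer:
#   load_checkpoint_stage1(config, model)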


def load_checkpoint(config, model, optimizer, lr_scheduler):
    logger = get_logger()
    logger.info(f'==============> Resuming from {config.checkpoint.resume}....................')
    checkpoint = CheckpointLoader.load_checkpoint(config.checkpoint.resume, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    metrics = defaultdict(float)
    # Restore the full training state only when resuming training; in
    # eval-only mode the model weights are enough.
    if (not config.evaluate.eval_only and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint
            and 'epoch' in checkpoint):
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        with read_write(config):
            config.train.start_epoch = checkpoint['epoch'] + 1
        # Restore apex mixed-precision state only if both the current run and
        # the checkpointed run used amp.
        if 'amp' in checkpoint and config.train.amp_opt_level != 'O0' and checkpoint[
                'config'].train.amp_opt_level != 'O0':
            amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.checkpoint.resume}' (epoch {checkpoint['epoch']})")
        metrics = checkpoint['metrics']
    del checkpoint
    torch.cuda.empty_cache()
    return metrics


def save_checkpoint(config, epoch, model, metrics, optimizer, lr_scheduler, suffix=''):
    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'metrics': metrics,
        'epoch': epoch,
        'config': config
    }
    logger = get_logger()
    # Also expose each metric at the top level of the checkpoint dict.
    for k, v in metrics.items():
        save_state[k] = v
    if config.train.amp_opt_level != 'O0':
        save_state['amp'] = amp.state_dict()
    if len(suffix) > 0 and not suffix.startswith('_') and suffix != 'best_miou':
        suffix = '_' + suffix
    # Keep a periodic snapshot every 10 epochs.
    if epoch >= 10 and epoch % 10 == 0 and suffix != 'best_miou':
        filename = f'ckpt_epoch_{epoch}{suffix}.pth'
        save_path = os.path.join(config.output, filename)
        torch.save(save_state, save_path)
    if suffix == 'best_miou':
        logger.info('saving best miou checkpoint')
        filename = 'best_miou.pth'  # only keep the single best checkpoint
        current_save_path = os.path.join(config.output, filename)
        torch.save(save_state, current_save_path)
        logger.info(f'{current_save_path} saved for best miou!!!')
    else:
        ##### this is for per-epoch saving, easy for resuming #####
        current_save_path = os.path.join(config.output, 'checkpoint.pth')
        torch.save(save_state, current_save_path)
        logger.info(f'{current_save_path} saved !!!')
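
# Usage sketch (illustrative): a training loop would call save_checkpoint once
# per epoch, and again with suffix='best_miou' when validation mIoU improves.
# `best_miou` below is a hypothetical variable tracked by the caller:
#   save_checkpoint(config, epoch, model, metrics, optimizer, lr_scheduler)
#   if metrics['miou'] > best_miou:
#       best_miou = metrics['miou']
#       save_checkpoint(config, epoch, model, metrics, optimizer, lr_scheduler,
#                       suffix='best_miou')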


def get_grad_norm(parameters, norm_type=2):
    # Global gradient norm: (sum_i ||g_i||_p^p)^(1/p) over all parameters
    # that currently have a gradient.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item()**norm_type
    total_norm = total_norm**(1. / norm_type)
    return total_norm
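
# Usage sketch (illustrative): log the gradient norm between backward and the
# optimizer step:
#   loss.backward()
#   grad_norm = get_grad_norm(model.parameters())
#   optimizer.step()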


def auto_resume_helper(output_dir):
    # Prefer the rolling checkpoint.pth; otherwise fall back to the most
    # recently modified .pth file in the output directory.
    if os.path.exists(os.path.join(output_dir, 'checkpoint.pth')):
        return os.path.join(output_dir, 'checkpoint.pth')
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    print(f'All checkpoints found in {output_dir}: {checkpoints}')
    if len(checkpoints) > 0:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        print(f'The latest checkpoint found: {latest_checkpoint}')
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file
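
# Usage sketch (illustrative): wiring auto-resume in a training entry point,
# assuming an omegaconf config shaped like the one used above:
#   resume_file = auto_resume_helper(config.output)
#   if resume_file is not None:
#       with read_write(config):
#           config.checkpoint.resume = resume_file
#       metrics = load_checkpoint(config, model, optimizer, lr_scheduler)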


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
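
# Usage sketch (illustrative): average a scalar across ranks for logging;
# requires torch.distributed to be initialized:
#   loss_avg = reduce_tensor(loss.detach())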