# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
# Modified by Jilan Xu
# -------------------------------------------------------------------------
import os
from collections import defaultdict

import torch
import torch.distributed as dist
from mmcv.runner import CheckpointLoader
from omegaconf import read_write

from .logger import get_logger
from ipdb import set_trace

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None

def load_checkpoint_stage1(config, model):
    """Load stage-1 pretrained weights into the model (strict=False)."""
    logger = get_logger()
    logger.info(f'==============> Resuming stage1 checkpoint from {config.checkpoint.stage1_checkpoint}....................')
    checkpoint = CheckpointLoader.load_checkpoint(config.checkpoint.stage1_checkpoint, map_location='cpu')

    # Build a state dict that maps stage-1 weights onto the current model.
    new_state_dict = {}
    # Parameters introduced after stage 1; there are no pretrained weights for them.
    new_params = ['logit_scale_mask']
    for k in model.state_dict():
        if k in new_params:
            continue
        if k in checkpoint['model']:
            new_state_dict[k] = checkpoint['model'][k]
        else:
            # Initialise momentum-encoder weights from the online encoder.
            oldk = k.replace('img_encoder_momentum', 'img_encoder')
            if oldk in checkpoint['model']:
                new_state_dict[k] = checkpoint['model'][oldk]
    msg = model.load_state_dict(new_state_dict, strict=False)
    logger.info(msg)

    del checkpoint
    torch.cuda.empty_cache()
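
# Usage sketch (not part of the original file): load stage-1 weights into a
# freshly built model before stage-2 training. `cfg` and `build_model` are
# hypothetical stand-ins for the project's config object and model factory.
#
#     model = build_model(cfg)
#     load_checkpoint_stage1(cfg, model)
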
def load_checkpoint(config, model, optimizer, lr_scheduler):
    """Resume model (and, when training, optimizer/scheduler/amp state) from a checkpoint."""
    logger = get_logger()
    logger.info(f'==============> Resuming from {config.checkpoint.resume}....................')
    checkpoint = CheckpointLoader.load_checkpoint(config.checkpoint.resume, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)

    metrics = defaultdict(float)
    if (not config.evaluate.eval_only and 'optimizer' in checkpoint
            and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint):
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        with read_write(config):
            config.train.start_epoch = checkpoint['epoch'] + 1
        if ('amp' in checkpoint and config.train.amp_opt_level != 'O0'
                and checkpoint['config'].train.amp_opt_level != 'O0'):
            amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.checkpoint.resume}' (epoch {checkpoint['epoch']})")
        metrics = checkpoint['metrics']

    del checkpoint
    torch.cuda.empty_cache()
    return metrics
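
# Usage sketch (assumed, not from this file): resuming a run from an explicit
# checkpoint path set in the config.
#
#     if cfg.checkpoint.resume:
#         metrics = load_checkpoint(cfg, model, optimizer, lr_scheduler)
#         # training then restarts at cfg.train.start_epoch
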
def save_checkpoint(config, epoch, model, metrics, optimizer, lr_scheduler, suffix=''):
    """Save the full training state: a rolling `checkpoint.pth`, periodic snapshots, and a best-mIoU snapshot."""
    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'metrics': metrics,
        'epoch': epoch,
        'config': config,
    }
    logger = get_logger()
    # Also store each metric at the top level of the checkpoint for easy inspection.
    for k, v in metrics.items():
        save_state[k] = v
    if config.train.amp_opt_level != 'O0':
        save_state['amp'] = amp.state_dict()

    if len(suffix) > 0 and not suffix.startswith('_') and suffix != 'best_miou':
        suffix = '_' + suffix

    # Keep a snapshot every 10 epochs, starting from epoch 10.
    if epoch >= 10 and epoch % 10 == 0 and suffix != 'best_miou':
        filename = f'ckpt_epoch_{epoch}{suffix}.pth'
        save_path = os.path.join(config.output, filename)
        torch.save(save_state, save_path)

    ##### this is for per epoch saving, easy for resuming #####
    # filename = f'ckpt_epoch_{suffix}.pth'  # only save the best one
    # save_path = os.path.join(config.output, filename)
    # logger.info(f'{save_path} saving......')

    if suffix == 'best_miou':
        logger.info('saving best miou checkpoint')
        filename = 'best_miou.pth'  # only save the best one
        current_save_path = os.path.join(config.output, filename)
        torch.save(save_state, current_save_path)
        logger.info(f'{current_save_path} saved for best miou!!!')
    else:
        current_save_path = os.path.join(config.output, 'checkpoint.pth')
        torch.save(save_state, current_save_path)
        logger.info(f'{current_save_path} saved !!!')

    # if config.checkpoint.max_kept > 0:
    #     if epoch >= config.checkpoint.max_kept:
    #         logger.info(f'Epoch: {epoch}, greater than config.checkpoint.max_kept: {config.checkpoint.max_kept}')
    #         end_clean_epoch = epoch - config.checkpoint.max_kept
    #         old_path_list = []
    #         for cur_clean_epoch in range(end_clean_epoch + 1):
    #             old_path = os.path.join(config.output, f'ckpt_epoch_{cur_clean_epoch}{suffix}.pth')
    #             if os.path.exists(old_path):
    #                 logger.info(f'old checkpoint path {old_path} exists')
    #                 old_path_list.append(old_path)
    #         for old_path in old_path_list[:-config.checkpoint.max_kept]:
    #             os.remove(old_path)
    #             logger.info(f'old checkpoint path {old_path} removed!!!')
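
# Hypothetical call site (a sketch, not from this repo): save the rolling
# checkpoint every epoch, and a `best_miou` snapshot when validation mIoU
# improves. `miou` and `best_miou` are illustrative names.
#
#     save_checkpoint(cfg, epoch, model, metrics, optimizer, lr_scheduler)
#     if miou > best_miou:
#         best_miou = miou
#         save_checkpoint(cfg, epoch, model, metrics, optimizer, lr_scheduler,
#                         suffix='best_miou')
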
def get_grad_norm(parameters, norm_type=2):
    """Compute the total `norm_type`-norm over the gradients of `parameters`."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm
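
# Example (a sketch): log the global gradient norm after backward() and before
# optimizer.step(), e.g. to monitor training stability.
#
#     loss.backward()
#     grad_norm = get_grad_norm(model.parameters())
#     logger.info(f'grad_norm: {grad_norm:.4f}')
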
def auto_resume_helper(output_dir):
    """Return the checkpoint to resume from: `checkpoint.pth` if present, else the newest `.pth` file."""
    if os.path.exists(os.path.join(output_dir, 'checkpoint.pth')):
        return os.path.join(output_dir, 'checkpoint.pth')
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    print(f'All checkpoints found in {output_dir}: {checkpoints}')
    if len(checkpoints) > 0:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        print(f'The latest checkpoint found: {latest_checkpoint}')
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file
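
# Sketch of the intended auto-resume flow (assumed from how these helpers fit
# together; the training script itself is not shown here):
#
#     resume_file = auto_resume_helper(cfg.output)
#     if resume_file is not None:
#         with read_write(cfg):
#             cfg.checkpoint.resume = resume_file
#         metrics = load_checkpoint(cfg, model, optimizer, lr_scheduler)
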
def reduce_tensor(tensor):
    """All-reduce a tensor across ranks and return the mean."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
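
# Usage sketch (assumes torch.distributed is already initialised, e.g. by the
# launch script): average a scalar loss across all ranks so every process logs
# the same value.
#
#     loss_avg = reduce_tensor(loss.detach())
#     if dist.get_rank() == 0:
#         logger.info(f'avg loss: {loss_avg.item():.4f}')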