# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
# Modified by Jilan Xu
# -------------------------------------------------------------------------
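# Checkpoint utilities: loading stage-1 weights, resuming/saving training
# state, gradient-norm computation, auto-resume lookup and distributed
# tensor reduction.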
import os
from collections import defaultdict

import torch
import torch.distributed as dist
from mmcv.runner import CheckpointLoader
from omegaconf import read_write

from .logger import get_logger

try:
    from apex import amp
except ImportError:
    amp = None

def load_checkpoint_stage1(config, model):
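    """Initialize ``model`` from the stage-1 checkpoint at
    ``config.checkpoint.stage1_checkpoint``.

    Parameters that are new in stage 2 (e.g. ``logit_scale_mask``) are skipped,
    and momentum-encoder weights are copied from the corresponding online
    ``img_encoder`` weights when no direct match exists. Loading is non-strict,
    so any remaining mismatches are only logged.
    """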
    logger = get_logger()
    logger.info(f'==============> Resuming stage1 checkpoint from {config.checkpoint.stage1_checkpoint}....................')
    checkpoint = CheckpointLoader.load_checkpoint(config.checkpoint.stage1_checkpoint, map_location='cpu')

    ### load online model parameters ###
    new_state_dict = {}
    new_params = ['logit_scale_mask']
    for k, v in model.state_dict().items():
        # parameters introduced after stage 1 keep their fresh initialization
        if k in new_params:
            continue
        if k in checkpoint['model']:
            new_state_dict[k] = checkpoint['model'][k]
        else:
            # initialize momentum encoder weights from the online image encoder
            oldk = k.replace('img_encoder_momentum', 'img_encoder')
            if oldk in checkpoint['model']:
                new_state_dict[k] = checkpoint['model'][oldk]

    msg = model.load_state_dict(new_state_dict, strict=False)
    logger.info(msg)

    del checkpoint
    torch.cuda.empty_cache()

def load_checkpoint(config, model, optimizer, lr_scheduler):
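    """Resume training from ``config.checkpoint.resume``.

    Model weights are always restored (non-strict). Unless running in
    evaluation-only mode, the optimizer, LR scheduler, start epoch and AMP
    state are restored as well. Returns the metrics stored in the checkpoint,
    or an empty ``defaultdict(float)`` when only the weights were loaded.
    """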
    logger = get_logger()
    logger.info(f'==============> Resuming from {config.checkpoint.resume}....................')
    checkpoint = CheckpointLoader.load_checkpoint(config.checkpoint.resume, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    metrics = defaultdict(float)
    # restore the full training state only when actually resuming training
    if (not config.evaluate.eval_only and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint
            and 'epoch' in checkpoint):
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        with read_write(config):
            config.train.start_epoch = checkpoint['epoch'] + 1
        if 'amp' in checkpoint and config.train.amp_opt_level != 'O0' and checkpoint[
                'config'].train.amp_opt_level != 'O0':
            amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.checkpoint.resume}' (epoch {checkpoint['epoch']})")
        metrics = checkpoint['metrics']

    del checkpoint
    torch.cuda.empty_cache()
    return metrics

def save_checkpoint(config, epoch, model, metrics, optimizer, lr_scheduler, suffix=''):
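    """Save model, optimizer, LR scheduler, AMP state and ``metrics``.

    Writes ``ckpt_epoch_{epoch}{suffix}.pth`` every 10 epochs, overwrites
    ``best_miou.pth`` when ``suffix == 'best_miou'``, and otherwise refreshes
    ``checkpoint.pth`` on every call so the latest epoch can be resumed.
    """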
    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'metrics': metrics,
        'epoch': epoch,
        'config': config
    }
    logger = get_logger()
    for k, v in metrics.items():
        save_state[k] = v
    if config.train.amp_opt_level != 'O0':
        save_state['amp'] = amp.state_dict()
    if len(suffix) > 0 and not suffix.startswith('_') and suffix != 'best_miou':
        suffix = '_' + suffix

    # keep a periodic snapshot every 10 epochs
    if epoch >= 10 and epoch % 10 == 0 and suffix != 'best_miou':
        filename = f'ckpt_epoch_{epoch}{suffix}.pth'
        save_path = os.path.join(config.output, filename)
        torch.save(save_state, save_path)

    if suffix == 'best_miou':
        print('saving best iou checkpoint')
        filename = 'best_miou.pth'  # only keep the single best checkpoint
        current_save_path = os.path.join(config.output, filename)
        torch.save(save_state, current_save_path)
        logger.info(f'{current_save_path} saved for best iou!!!')
    else:
        ##### per-epoch saving, easy for resuming #####
        current_save_path = os.path.join(config.output, 'checkpoint.pth')
        torch.save(save_state, current_save_path)
        logger.info(f'{current_save_path} saved !!!')

def get_grad_norm(parameters, norm_type=2):
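    """Return the total ``norm_type`` norm of the gradients of ``parameters``;
    parameters without gradients are ignored."""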
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item()**norm_type
    total_norm = total_norm**(1. / norm_type)
    return total_norm

def auto_resume_helper(output_dir):
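    """Return the checkpoint to auto-resume from: ``checkpoint.pth`` in
    ``output_dir`` if present, otherwise the most recently modified ``*.pth``
    file, or ``None`` when no checkpoint exists."""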
    if os.path.exists(os.path.join(output_dir, 'checkpoint.pth')):
        return os.path.join(output_dir, 'checkpoint.pth')

    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    print(f'All checkpoints found in {output_dir}: {checkpoints}')
    if len(checkpoints) > 0:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        print(f'The latest checkpoint found: {latest_checkpoint}')
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file

def reduce_tensor(tensor):
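    """All-reduce a copy of ``tensor`` and average it over all distributed processes."""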
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt