main_seg.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # ------------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
  3. #
  4. # This work is made available under the Nvidia Source Code License.
  5. # To view a copy of this license, visit
  6. # https://github.com/NVlabs/GroupViT/blob/main/LICENSE
  7. #
  8. # Written by Jiarui Xu
  9. # ------------------------------------------------------------------------------
  10. import argparse
  11. import os
  12. import os.path as osp
  13. import mmcv
  14. import torch
  15. import torch.backends.cudnn as cudnn
  16. import torch.distributed as dist
  17. from datasets import build_text_transform
  18. from main_group_vit import validate_seg
  19. from mmcv.image import tensor2imgs
  20. from mmcv.parallel import MMDistributedDataParallel
  21. from mmcv.runner import set_random_seed
  22. from models import build_model
  23. from omegaconf import OmegaConf, read_write
  24. from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
  25. from utils import get_config, get_logger, load_checkpoint
  26. try:
  27. # noinspection PyUnresolvedReferences
  28. from apex import amp
  29. except ImportError:
  30. amp = None
  31. def parse_args():
  32. parser = argparse.ArgumentParser('GroupViT segmentation evaluation and visualization')
  33. parser.add_argument(
  34. '--cfg',
  35. type=str,
  36. required=True,
  37. help='path to config file',
  38. )
  39. parser.add_argument(
  40. '--opts',
  41. help="Modify config options by adding 'KEY VALUE' pairs. ",
  42. default=None,
  43. nargs='+',
  44. )
  45. parser.add_argument('--resume', help='resume from checkpoint')
  46. parser.add_argument(
  47. '--output', type=str, help='root of output folder, '
  48. 'the full path is <output>/<model_name>/<tag>')
  49. parser.add_argument('--tag', help='tag of experiment')
  50. parser.add_argument(
  51. '--vis',
  52. help='Specify the visualization mode, '
  53. 'could be a list, support input, pred, input_seg, input_pred_seg_label, all_groups, first_group, last_group',
  54. default=None,
  55. nargs='+')
  56. # distributed training
  57. parser.add_argument('--local_rank', type=int, required=True, help='local rank for DistributedDataParallel')
  58. args = parser.parse_args()
  59. return args
  60. def inference(cfg):
  61. logger = get_logger()
  62. data_loader = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))
  63. dataset = data_loader.dataset
  64. logger.info(f'Evaluating dataset: {dataset}')
  65. logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
  66. model = build_model(cfg.model)
  67. model.cuda()
  68. logger.info(str(model))
  69. if cfg.train.amp_opt_level != 'O0':
  70. model = amp.initialize(model, None, opt_level=cfg.train.amp_opt_level)
  71. n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
  72. logger.info(f'number of params: {n_parameters}')
  73. load_checkpoint(cfg, model, None, None)
  74. if 'seg' in cfg.evaluate.task:
  75. miou = validate_seg(cfg, data_loader, model)
  76. logger.info(f'mIoU of the network on the {len(data_loader.dataset)} test images: {miou:.2f}%')
  77. else:
  78. logger.info('No segmentation evaluation specified')
  79. if cfg.vis:
  80. vis_seg(cfg, data_loader, model, cfg.vis)
  81. @torch.no_grad()
  82. def vis_seg(config, data_loader, model, vis_modes):
  83. dist.barrier()
  84. model.eval()
  85. if hasattr(model, 'module'):
  86. model_without_ddp = model.module
  87. else:
  88. model_without_ddp = model
  89. text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
  90. seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)
  91. mmddp_model = MMDistributedDataParallel(
  92. seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
  93. mmddp_model.eval()
  94. model = mmddp_model.module
  95. device = next(model.parameters()).device
  96. dataset = data_loader.dataset
  97. if dist.get_rank() == 0:
  98. prog_bar = mmcv.ProgressBar(len(dataset))
  99. loader_indices = data_loader.batch_sampler
  100. for batch_indices, data in zip(loader_indices, data_loader):
  101. with torch.no_grad():
  102. result = mmddp_model(return_loss=False, **data)
  103. img_tensor = data['img'][0]
  104. img_metas = data['img_metas'][0].data[0]
  105. imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
  106. assert len(imgs) == len(img_metas)
  107. for batch_idx, img, img_meta in zip(batch_indices, imgs, img_metas):
  108. h, w, _ = img_meta['img_shape']
  109. img_show = img[:h, :w, :]
  110. ori_h, ori_w = img_meta['ori_shape'][:-1]
  111. img_show = mmcv.imresize(img_show, (ori_w, ori_h))
  112. for vis_mode in vis_modes:
  113. out_file = osp.join(config.output, 'vis_imgs', vis_mode, f'{batch_idx:04d}.jpg')
  114. model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)
  115. if dist.get_rank() == 0:
  116. batch_size = len(result) * dist.get_world_size()
  117. for _ in range(batch_size):
  118. prog_bar.update()
  119. def main():
  120. args = parse_args()
  121. cfg = get_config(args)
  122. if cfg.train.amp_opt_level != 'O0':
  123. assert amp is not None, 'amp not installed!'
  124. with read_write(cfg):
  125. cfg.evaluate.eval_only = True
  126. if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
  127. rank = int(os.environ['RANK'])
  128. world_size = int(os.environ['WORLD_SIZE'])
  129. print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')
  130. else:
  131. rank = -1
  132. world_size = -1
  133. torch.cuda.set_device(cfg.local_rank)
  134. dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
  135. dist.barrier()
  136. set_random_seed(cfg.seed, use_rank_shift=True)
  137. cudnn.benchmark = True
  138. os.makedirs(cfg.output, exist_ok=True)
  139. logger = get_logger(cfg)
  140. if dist.get_rank() == 0:
  141. path = os.path.join(cfg.output, 'config.json')
  142. OmegaConf.save(cfg, path)
  143. logger.info(f'Full config saved to {path}')
  144. # print config
  145. logger.info(OmegaConf.to_yaml(cfg))
  146. inference(cfg)
  147. dist.barrier()
  148. if __name__ == '__main__':
  149. main()