main_seg.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # ------------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # ------------------------------------------------------------------------------
  13. import argparse
  14. import os
  15. import os.path as osp
  16. import mmcv
  17. import torch
  18. import torch.backends.cudnn as cudnn
  19. import torch.distributed as dist
  20. from datasets import build_text_transform
  21. from main_group_vit import validate_seg
  22. from mmcv.image import tensor2imgs
  23. from mmcv.parallel import MMDistributedDataParallel
  24. from mmcv.runner import set_random_seed
  25. from models import build_model
  26. from omegaconf import OmegaConf, read_write
  27. from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
  28. from utils import get_config, get_logger, load_checkpoint
  29. try:
  30. # noinspection PyUnresolvedReferences
  31. from apex import amp
  32. except ImportError:
  33. amp = None
  34. def parse_args():
  35. parser = argparse.ArgumentParser('GroupViT segmentation evaluation and visualization')
  36. parser.add_argument(
  37. '--cfg',
  38. type=str,
  39. required=True,
  40. help='path to config file',
  41. )
  42. parser.add_argument(
  43. '--opts',
  44. help="Modify config options by adding 'KEY VALUE' pairs. ",
  45. default=None,
  46. nargs='+',
  47. )
  48. parser.add_argument('--resume', help='resume from checkpoint')
  49. parser.add_argument(
  50. '--output', type=str, help='root of output folder, '
  51. 'the full path is <output>/<model_name>/<tag>')
  52. parser.add_argument('--tag', help='tag of experiment')
  53. parser.add_argument(
  54. '--vis',
  55. help='Specify the visualization mode, '
  56. 'could be a list, support input, pred, input_seg, input_pred_seg_label, all_groups, first_group, last_group',
  57. default=None,
  58. nargs='+')
  59. # distributed training
  60. parser.add_argument('--local_rank', type=int, required=True, help='local rank for DistributedDataParallel')
  61. args = parser.parse_args()
  62. return args
  63. def inference(cfg):
  64. logger = get_logger()
  65. data_loader = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))
  66. dataset = data_loader.dataset
  67. logger.info(f'Evaluating dataset: {dataset}')
  68. logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
  69. model = build_model(cfg.model)
  70. model.cuda()
  71. logger.info(str(model))
  72. if cfg.train.amp_opt_level != 'O0':
  73. model = amp.initialize(model, None, opt_level=cfg.train.amp_opt_level)
  74. n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
  75. logger.info(f'number of params: {n_parameters}')
  76. load_checkpoint(cfg, model, None, None)
  77. if 'seg' in cfg.evaluate.task:
  78. miou = validate_seg(cfg, data_loader, model)
  79. logger.info(f'mIoU of the network on the {len(data_loader.dataset)} test images: {miou:.2f}%')
  80. else:
  81. logger.info('No segmentation evaluation specified')
  82. if cfg.vis:
  83. vis_seg(cfg, data_loader, model, cfg.vis)
@torch.no_grad()
def vis_seg(config, data_loader, model, vis_modes):
    """Render and save qualitative segmentation results for the whole dataset.

    For every batch, runs inference and writes one image per requested mode to
    ``<config.output>/vis_imgs/<vis_mode>/<index>.jpg``.

    Args:
        config: experiment config; reads ``config.output``,
            ``config.data.text_aug`` and ``config.evaluate.seg``.
        data_loader: segmentation evaluation dataloader (mmseg-style; batches
            carry 'img' and 'img_metas' DataContainers).
        model: trained model, possibly wrapped in DistributedDataParallel.
        vis_modes: list of visualization mode names forwarded to
            ``model.show_result``.
    """
    # Synchronize all ranks before starting.
    dist.barrier()
    model.eval()

    # Unwrap DDP so the raw model can be handed to the inference builder.
    if hasattr(model, 'module'):
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    # Text transform for encoding class names (eval mode, no DataContainer).
    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
    seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)

    # Re-wrap the inference model so each rank processes its own data shard.
    mmddp_model = MMDistributedDataParallel(
        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    mmddp_model.eval()
    model = mmddp_model.module
    device = next(model.parameters()).device
    dataset = data_loader.dataset

    # Progress bar only on rank 0 to avoid duplicated console output.
    if dist.get_rank() == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    # batch_sampler yields the dataset indices of each batch; they name the
    # output files below.
    loader_indices = data_loader.batch_sampler
    for batch_indices, data in zip(loader_indices, data_loader):
        with torch.no_grad():
            result = mmddp_model(return_loss=False, **data)

        img_tensor = data['img'][0]
        img_metas = data['img_metas'][0].data[0]
        # Undo the test-pipeline normalization so images are viewable.
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)

        for batch_idx, img, img_meta in zip(batch_indices, imgs, img_metas):
            h, w, _ = img_meta['img_shape']
            # Crop away padding, then resize back to the original resolution.
            img_show = img[:h, :w, :]

            ori_h, ori_w = img_meta['ori_shape'][:-1]
            img_show = mmcv.imresize(img_show, (ori_w, ori_h))

            for vis_mode in vis_modes:
                out_file = osp.join(config.output, 'vis_imgs', vis_mode, f'{batch_idx:04d}.jpg')
                model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)
            if dist.get_rank() == 0:
                # Advance by the global batch size: every rank handles
                # len(result) samples per step.
                batch_size = len(result) * dist.get_world_size()
                for _ in range(batch_size):
                    prog_bar.update()
  122. def main():
  123. args = parse_args()
  124. cfg = get_config(args)
  125. if cfg.train.amp_opt_level != 'O0':
  126. assert amp is not None, 'amp not installed!'
  127. with read_write(cfg):
  128. cfg.evaluate.eval_only = True
  129. if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
  130. rank = int(os.environ['RANK'])
  131. world_size = int(os.environ['WORLD_SIZE'])
  132. print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')
  133. else:
  134. rank = -1
  135. world_size = -1
  136. torch.cuda.set_device(cfg.local_rank)
  137. dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
  138. dist.barrier()
  139. set_random_seed(cfg.seed, use_rank_shift=True)
  140. cudnn.benchmark = True
  141. os.makedirs(cfg.output, exist_ok=True)
  142. logger = get_logger(cfg)
  143. if dist.get_rank() == 0:
  144. path = os.path.join(cfg.output, 'config.json')
  145. OmegaConf.save(cfg, path)
  146. logger.info(f'Full config saved to {path}')
  147. # print config
  148. logger.info(OmegaConf.to_yaml(cfg))
  149. inference(cfg)
  150. dist.barrier()
  151. if __name__ == '__main__':
  152. main()