demo_seg.py

# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
import argparse
import os.path as osp
import sys

# Make the repository root importable so the top-level packages
# (`datasets`, `models`, `segmentation`, `utils`) resolve when this
# script is run from inside the demo/ directory.
parentdir = osp.dirname(osp.dirname(__file__))
sys.path.insert(0, parentdir)

import mmcv
import torch
from datasets import build_text_transform
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.image import tensor2imgs
from mmcv.parallel import collate, scatter
from models import build_model
from omegaconf import read_write
from segmentation.datasets import COCOObjectDataset, PascalContextDataset, PascalVOCDataset
from segmentation.evaluation import build_seg_demo_pipeline, build_seg_inference
from utils import get_config, load_checkpoint

def parse_args():
    parser = argparse.ArgumentParser('GroupViT demo')
    parser.add_argument(
        '--cfg',
        type=str,
        required=True,
        help='path to config file',
    )
    parser.add_argument(
        '--opts',
        help="Modify config options by adding 'KEY VALUE' pairs.",
        default=None,
        nargs='+',
    )
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument(
        '--vis',
        help='Specify the visualization mode; '
        'can be a list, supporting "input", "pred", "input_pred", "all_groups", "first_group", "final_group", "input_pred_label"',
        default=None,
        nargs='+')
    parser.add_argument('--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--dataset', default='voc', choices=['voc', 'coco', 'context'], help='dataset classes for visualization')
    parser.add_argument('--input', type=str, help='input image path')
    parser.add_argument('--output_dir', type=str, help='output dir')

    args = parser.parse_args()
    args.local_rank = 0  # compatible with config

    return args
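
# A typical invocation might look like the following. The config and
# checkpoint paths are illustrative (they depend on your local setup),
# not files guaranteed to ship with the repository:
#
#   python demo/demo_seg.py \
#       --cfg configs/group_vit_gcc_yfcc_30e.yml \
#       --resume /path/to/group_vit_checkpoint.pth \
#       --input demo/examples/voc.jpg \
#       --output_dir demo/output \
#       --vis input_pred \
#       --dataset voc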


def inference(args, cfg):
    model = build_model(cfg.model)
    # SyncBN assumes distributed training; convert it back to regular
    # BatchNorm for single-device inference.
    model = revert_sync_batchnorm(model)
    model.to(args.device)
    model.eval()

    load_checkpoint(cfg, model, None, None)

    text_transform = build_text_transform(False, cfg.data.text_aug, with_dc=False)
    # The chosen dataset determines the class names (and palette) the text
    # encoder is queried with.
    if args.dataset == 'voc':
        dataset_class = PascalVOCDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_voc12.py'
    elif args.dataset == 'coco':
        dataset_class = COCOObjectDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/coco_object164k.py'
    elif args.dataset == 'context':
        dataset_class = PascalContextDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_context.py'
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    with read_write(cfg):
        cfg.evaluate.seg.cfg = seg_cfg
        # 'whole' runs inference on the full image at once rather than with a
        # sliding window.
        cfg.evaluate.seg.opts = ['test_cfg.mode=whole']

    seg_model = build_seg_inference(model, dataset_class, text_transform, cfg.evaluate.seg)

    vis_seg(seg_model, args.input, args.output_dir, args.vis)


def vis_seg(seg_model, input_img, output_dir, vis_modes):
    device = next(seg_model.parameters()).device
    test_pipeline = build_seg_demo_pipeline()
    # prepare data
    data = dict(img=input_img)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(seg_model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        # running on CPU: unwrap the DataContainer manually
        data['img_metas'] = [i.data[0] for i in data['img_metas']]
    with torch.no_grad():
        result = seg_model(return_loss=False, rescale=True, **data)

    img_tensor = data['img'][0]
    img_metas = data['img_metas'][0]
    # Undo the pipeline's normalization so the images can be rendered.
    imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
    assert len(imgs) == len(img_metas)

    for img, img_meta in zip(imgs, img_metas):
        # Crop away the padding added by the pipeline, then resize back to
        # the original resolution.
        h, w, _ = img_meta['img_shape']
        img_show = img[:h, :w, :]

        ori_h, ori_w = img_meta['ori_shape'][:-1]
        img_show = mmcv.imresize(img_show, (ori_w, ori_h))

        for vis_mode in vis_modes:
            out_file = osp.join(output_dir, 'vis_imgs', vis_mode, f'{vis_mode}.jpg')
            seg_model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)
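
# Note: each visualization mode writes to the fixed path
# '<output_dir>/vis_imgs/<vis_mode>/<vis_mode>.jpg', so running the demo on a
# second image with the same output_dir overwrites the previous result. To
# keep one file per input, the name could be derived from the image instead;
# a minimal sketch, assuming the demo pipeline records 'ori_filename' in
# img_meta as mmseg's LoadImage does:
#
#   stem = osp.splitext(osp.basename(img_meta['ori_filename']))[0]
#   out_file = osp.join(output_dir, 'vis_imgs', vis_mode, f'{stem}_{vis_mode}.jpg')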


def main():
    args = parse_args()
    cfg = get_config(args)

    with read_write(cfg):
        cfg.evaluate.eval_only = True

    inference(args, cfg)


if __name__ == '__main__':
    main()
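
# The demo can also be driven programmatically, bypassing the CLI. A minimal
# sketch, assuming get_config() only reads the fields that parse_args() sets
# (all argument values here are illustrative):
#
#   from argparse import Namespace
#   args = Namespace(cfg='configs/group_vit_gcc_yfcc_30e.yml', opts=None,
#                    resume='/path/to/checkpoint.pth', vis=['input_pred'],
#                    device='cuda:0', dataset='voc', input='image.jpg',
#                    output_dir='out', local_rank=0)
#   cfg = get_config(args)
#   with read_write(cfg):
#       cfg.evaluate.eval_only = True
#   inference(args, cfg)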