# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
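"""GroupViT zero-shot segmentation demo.

Builds a GroupViT model from a config file, loads a checkpoint, wraps the
model with the class names of a chosen dataset (Pascal VOC, COCO Object, or
Pascal Context), runs segmentation inference on a single input image, and
writes the requested visualizations under the output directory.
"""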
import argparse
import os.path as osp
import sys

# Make the repo-level packages (datasets, models, segmentation, utils)
# importable when this script is run directly from its subdirectory.
parentdir = osp.dirname(osp.dirname(__file__))
sys.path.insert(0, parentdir)

import mmcv
import torch
from datasets import build_text_transform
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.image import tensor2imgs
from mmcv.parallel import collate, scatter
from models import build_model
from omegaconf import read_write
from segmentation.datasets import COCOObjectDataset, PascalContextDataset, PascalVOCDataset
from segmentation.evaluation import build_seg_demo_pipeline, build_seg_inference
from utils import get_config, load_checkpoint


def parse_args():
    parser = argparse.ArgumentParser('GroupViT demo')
    parser.add_argument(
        '--cfg',
        type=str,
        required=True,
        help='path to config file',
    )
    parser.add_argument(
        '--opts',
        help="Modify config options by adding 'KEY VALUE' pairs.",
        default=None,
        nargs='+',
    )
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument(
        '--vis',
        help='Visualization mode(s), one or more of: '
        '"input", "pred", "input_pred", "all_groups", "first_group", "final_group", "input_pred_label"',
        default=None,
        nargs='+')
    parser.add_argument('--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--dataset', default='voc', choices=['voc', 'coco', 'context'], help='dataset classes for visualization')
    parser.add_argument('--input', type=str, help='input image path')
    parser.add_argument('--output_dir', type=str, help='output dir')
    args = parser.parse_args()

    args.local_rank = 0  # compatible with config

    return args
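

# Example invocation (the script and file paths below are illustrative
# placeholders, not files shipped with this repository):
#   python demo/demo_seg.py \
#       --cfg /path/to/config.yml \
#       --resume /path/to/checkpoint.pth \
#       --input /path/to/image.jpg \
#       --output_dir output/ \
#       --dataset voc \
#       --vis input_pred final_group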


def inference(args, cfg):
    model = build_model(cfg.model)
    model = revert_sync_batchnorm(model)
    model.to(args.device)
    model.eval()

    load_checkpoint(cfg, model, None, None)

    text_transform = build_text_transform(False, cfg.data.text_aug, with_dc=False)

    # Pick the dataset whose class names define the zero-shot label set.
    if args.dataset == 'voc':
        dataset_class = PascalVOCDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_voc12.py'
    elif args.dataset == 'coco':
        dataset_class = COCOObjectDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/coco_object164k.py'
    elif args.dataset == 'context':
        dataset_class = PascalContextDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_context.py'
    else:
        raise ValueError(f'Unknown dataset: {args.dataset}')

    with read_write(cfg):
        cfg.evaluate.seg.cfg = seg_cfg
        cfg.evaluate.seg.opts = ['test_cfg.mode=whole']

    seg_model = build_seg_inference(model, dataset_class, text_transform, cfg.evaluate.seg)

    vis_seg(seg_model, args.input, args.output_dir, args.vis)


def vis_seg(seg_model, input_img, output_dir, vis_modes):
    device = next(seg_model.parameters()).device
    test_pipeline = build_seg_demo_pipeline()

    # Prepare data: load the image, run the test pipeline, and batch it.
    data = dict(img=input_img)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(seg_model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        data['img_metas'] = [i.data[0] for i in data['img_metas']]

    with torch.no_grad():
        result = seg_model(return_loss=False, rescale=True, **data)

    img_tensor = data['img'][0]
    img_metas = data['img_metas'][0]
    # Undo the normalization so the images can be displayed.
    imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
    assert len(imgs) == len(img_metas)

    for img, img_meta in zip(imgs, img_metas):
        h, w, _ = img_meta['img_shape']
        img_show = img[:h, :w, :]

        # Resize back to the original image size for visualization.
        ori_h, ori_w = img_meta['ori_shape'][:-1]
        img_show = mmcv.imresize(img_show, (ori_w, ori_h))

        for vis_mode in vis_modes:
            out_file = osp.join(output_dir, 'vis_imgs', vis_mode, f'{vis_mode}.jpg')
            seg_model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)


def main():
    args = parse_args()
    cfg = get_config(args)

    with read_write(cfg):
        cfg.evaluate.eval_only = True

    inference(args, cfg)


if __name__ == '__main__':
    main()