# main_demo.py — OVSegmentor open-vocabulary segmentation demo entry point.
# -------------------------------------------------------------------------
# Written by Jilan Xu
# -------------------------------------------------------------------------
import argparse
import os
import os.path as osp
import subprocess

import mmcv
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from ipdb import set_trace
from mmcv.image import tensor2imgs
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import set_random_seed
from omegaconf import OmegaConf, read_write
from transformers import AutoTokenizer, RobertaTokenizer

from datasets import build_text_transform
from main_pretrain import init_distributed_mode, validate_seg
from models import build_model
from segmentation.evaluation import (build_custom_seg_dataset,
                                     build_demo_inference,
                                     build_seg_dataloader, build_seg_dataset,
                                     build_seg_inference)
from utils import get_config, get_logger, load_checkpoint

try:
    # noinspection PyUnresolvedReferences
    from apex import amp  # optional mixed-precision support
except ImportError:
    amp = None
# Maps a text-encoder type (cfg.model.text_encoder.type) to a ready-made
# tokenizer instance.  'TextTransformer' tokenizes internally, hence None.
# NOTE(review): from_pretrained runs at import time and may hit the network
# to download tokenizer files — confirm this is acceptable for deployment.
tokenizer_dict = {
    'Bert': AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False),
    'Roberta': RobertaTokenizer.from_pretrained('roberta-base'),
    'TextTransformer': None,
}
  35. def parse_args():
  36. parser = argparse.ArgumentParser('OVSegmentor segmentation demo')
  37. parser.add_argument(
  38. '--cfg',
  39. type=str,
  40. default='./configs/test_voc12.yml',
  41. help='path to config file',
  42. )
  43. parser.add_argument(
  44. '--resume',
  45. type=str,
  46. required=True,
  47. help='resume from checkpoint',
  48. )
  49. parser.add_argument(
  50. '--image_folder',
  51. type=str,
  52. required=True,
  53. help='path to the input image folder',
  54. )
  55. parser.add_argument(
  56. '--vocab',
  57. help='could be a list of candidate vocabularies, use given classes from [voc, coco, ade], or give a custom list of classes',
  58. default='voc',
  59. nargs='+',
  60. )
  61. parser.add_argument(
  62. '--output_folder',
  63. type=str,
  64. help='root of output folder',
  65. )
  66. parser.add_argument(
  67. '--vis',
  68. help='Specify the visualization mode, '
  69. 'could be a list, support input, pred, input_seg, input_pred_seg_label, all_groups, first_group, last_group, mask',
  70. default='input_pred_seg_label',
  71. nargs='+',
  72. )
  73. parser.add_argument(
  74. '--opts',
  75. help="Modify config options by adding 'KEY VALUE' pairs. ",
  76. default=None,
  77. nargs='+',
  78. )
  79. # distributed training
  80. parser.add_argument('--local_rank', type=int, required=False, default=0, help='local rank for DistributedDataParallel')
  81. args = parser.parse_args()
  82. return args
  83. def generate_imagelist_with_sanity_check(root):
  84. image_list = []
  85. for each_file in os.listdir(root):
  86. ### assume we process all .jpg files, and convert png to jpg files
  87. if each_file.endswith('.jpg'):
  88. pass
  89. elif each_file.endswith('.png'):
  90. img = mmcv.imread(osp.join(root, each_file))
  91. mmcv.imwrite(img, osp.join(root, each_file.replace('.png','.jpg')))
  92. else:
  93. continue
  94. filename = each_file.split('.')[0]
  95. image_list.append(filename)
  96. if len(image_list) == 0:
  97. raise ValueError(f'No image found in {args.image_folder}')
  98. with open(os.path.join(root, 'image_list.txt'), 'w') as f:
  99. for item in image_list:
  100. f.write("%s\n" % item)
  101. return image_list
def inference(cfg):
    """Build the model and data loader from ``cfg``, then run visualization.

    Side effects: writes ``image_list.txt`` into ``cfg.image_folder``,
    creates ``cfg.output_folder``, and sets the module-level ``tokenizer``
    global consumed by ``vis_seg``.
    """
    logger = get_logger()

    ### check and generate image list ###
    generate_imagelist_with_sanity_check(cfg.image_folder)
    os.makedirs(cfg.output_folder, exist_ok=True)

    data_loader = build_seg_dataloader(build_custom_seg_dataset(cfg.evaluate.seg, cfg))
    dataset = data_loader.dataset
    print('whether activating visualization: ', cfg.vis)
    logger.info(f'Evaluating dataset: {dataset}')

    logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
    model = build_model(cfg.model)
    model.cuda()
    logger.info(str(model))

    # Optional apex mixed precision; 'O0' means full fp32, so skip amp then.
    if cfg.train.amp_opt_level != 'O0':
        model = amp.initialize(model, None, opt_level=cfg.train.amp_opt_level)

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f'number of params: {n_parameters}')

    # Restore weights before selecting the tokenizer used at inference time.
    load_checkpoint(cfg, model, None, None)

    # Expose the tokenizer matching the configured text encoder to vis_seg.
    global tokenizer
    tokenizer = tokenizer_dict[cfg.model.text_encoder.type]

    if cfg.vis:
        vis_seg(cfg, data_loader, model, cfg.vis)
  124. @torch.no_grad()
  125. def vis_seg(config, data_loader, model, vis_modes):
  126. dist.barrier()
  127. model.eval()
  128. if hasattr(model, 'module'):
  129. model_without_ddp = model.module
  130. else:
  131. model_without_ddp = model
  132. text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
  133. seg_model = build_demo_inference(model_without_ddp, text_transform, config, tokenizer)
  134. mmddp_model = MMDistributedDataParallel(
  135. seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
  136. mmddp_model.eval()
  137. model = mmddp_model.module
  138. device = next(model.parameters()).device
  139. dataset = data_loader.dataset
  140. if dist.get_rank() == 0:
  141. prog_bar = mmcv.ProgressBar(len(dataset))
  142. loader_indices = data_loader.batch_sampler
  143. for batch_indices, data in zip(loader_indices, data_loader):
  144. with torch.no_grad():
  145. result = mmddp_model(return_loss=False, **data)
  146. # set_trace()
  147. img_tensor = data['img'][0]
  148. img_metas = data['img_metas'][0].data[0]
  149. imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
  150. assert len(imgs) == len(img_metas)
  151. for batch_idx, img, img_meta in zip(batch_indices, imgs, img_metas):
  152. h, w, _ = img_meta['img_shape']
  153. img_show = img[:h, :w, :]
  154. ori_h, ori_w = img_meta['ori_shape'][:-1]
  155. img_show = mmcv.imresize(img_show, (ori_w, ori_h))
  156. for vis_mode in vis_modes:
  157. out_file = osp.join(config.output_folder, vis_mode, f'{batch_idx:04d}.jpg')
  158. # os.makedirs(osp.join(config.output_folder, 'vis_imgs', vis_mode), exist_ok=True)
  159. print(osp.join(config.output_folder, vis_mode))
  160. model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)
  161. if dist.get_rank() == 0:
  162. batch_size = len(result) * dist.get_world_size()
  163. for _ in range(batch_size):
  164. prog_bar.update()
def main():
    """Entry point: parse args, set up distributed state, run inference."""
    args = parse_args()
    cfg = get_config(args)

    # Mixed precision was requested but apex failed to import at module load.
    if cfg.train.amp_opt_level != 'O0':
        assert amp is not None, 'amp not installed!'

    # The OmegaConf config is read-only by default; unlock to flip eval mode.
    with read_write(cfg):
        cfg.evaluate.eval_only = True

    init_distributed_mode(args)
    rank, world_size = args.rank, args.world_size

    # rank-shifted seeding keeps per-process randomness decorrelated.
    set_random_seed(cfg.seed, use_rank_shift=True)
    cudnn.benchmark = True

    os.makedirs(cfg.output, exist_ok=True)
    logger = get_logger(cfg)

    # Only rank 0 persists the resolved config to disk.
    if dist.get_rank() == 0:
        path = os.path.join(cfg.output, 'config.json')
        OmegaConf.save(cfg, path)
        logger.info(f'Full config saved to {path}')

    # print config
    logger.info(OmegaConf.to_yaml(cfg))

    inference(cfg)
    dist.barrier()

if __name__ == '__main__':
    main()