# main_seg.py
  1. # ------------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # ------------------------------------------------------------------------------
  13. # Modified by Jilan Xu
  14. # -------------------------------------------------------------------------
  15. import argparse
  16. import os
  17. import os.path as osp
  18. import subprocess
  19. import mmcv
  20. import torch
  21. import torch.backends.cudnn as cudnn
  22. import torch.distributed as dist
  23. from datasets import build_text_transform
  24. from main_pretrain import validate_seg
  25. from mmcv.image import tensor2imgs
  26. from mmcv.parallel import MMDistributedDataParallel
  27. from mmcv.runner import set_random_seed
  28. from models import build_model
  29. from omegaconf import OmegaConf, read_write
  30. from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
  31. from utils import get_config, get_logger, load_checkpoint
  32. from transformers import AutoTokenizer, RobertaTokenizer
  33. from ipdb import set_trace
  34. from main_pretrain import init_distributed_mode
# Apex (NVIDIA mixed-precision) is optional: fall back to ``amp = None`` when
# it is not installed so callers can guard with ``amp is not None`` before use.
try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None
# Maps the configured text-encoder type to a ready-to-use tokenizer instance.
# Tokenizers are instantiated eagerly at import time (may download from the
# HuggingFace hub on first run).
# NOTE(review): TOKENIZERS_PARALLELISM is normally an environment variable,
# not a ``from_pretrained`` kwarg — confirm it has the intended effect here.
tokenizer_dict = {
    'Bert': AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False),
    # 'Roberta': RobertaTokenizer.from_pretrained('/mnt/petrelfs/xujilan/roberta-base/'),
    'Roberta': RobertaTokenizer.from_pretrained('roberta-base'),
    # The built-in TextTransformer encoder tokenizes internally; no tokenizer needed.
    'TextTransformer': None,
}
  46. def parse_args():
  47. parser = argparse.ArgumentParser('OVSegmentor segmentation evaluation and visualization')
  48. parser.add_argument(
  49. '--cfg',
  50. type=str,
  51. required=True,
  52. help='path to config file',
  53. )
  54. parser.add_argument(
  55. '--opts',
  56. help="Modify config options by adding 'KEY VALUE' pairs. ",
  57. default=None,
  58. nargs='+',
  59. )
  60. parser.add_argument('--resume', help='resume from checkpoint')
  61. parser.add_argument(
  62. '--output', type=str, help='root of output folder, '
  63. 'the full path is <output>/<model_name>/<tag>')
  64. parser.add_argument('--tag', help='tag of experiment')
  65. parser.add_argument(
  66. '--vis',
  67. help='Specify the visualization mode, '
  68. 'could be a list, support input, pred, input_seg, input_pred_seg_label, all_groups, first_group, last_group',
  69. default=None,
  70. nargs='+')
  71. # distributed training
  72. parser.add_argument('--local_rank', type=int, required=False, default=0, help='local rank for DistributedDataParallel')
  73. args = parser.parse_args()
  74. return args
  75. def inference(cfg):
  76. logger = get_logger()
  77. data_loader = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))
  78. dataset = data_loader.dataset
  79. print('whether activating visualization: ', cfg.vis)
  80. logger.info(f'Evaluating dataset: {dataset}')
  81. logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
  82. model = build_model(cfg.model)
  83. model.cuda()
  84. logger.info(str(model))
  85. if cfg.train.amp_opt_level != 'O0':
  86. model = amp.initialize(model, None, opt_level=cfg.train.amp_opt_level)
  87. n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
  88. logger.info(f'number of params: {n_parameters}')
  89. load_checkpoint(cfg, model, None, None)
  90. global tokenizer
  91. tokenizer = tokenizer_dict[cfg.model.text_encoder.type]
  92. if cfg.model.text_encoder.type == 'Roberta':
  93. tokenizer = RobertaTokenizer.from_pretrained('/mnt/petrelfs/xujilan/roberta-base/')
  94. print('Done switching roberta tokenizer')
  95. if 'seg' in cfg.evaluate.task:
  96. miou = validate_seg(cfg, data_loader, model, tokenizer=tokenizer)
  97. logger.info(f'mIoU of the network on the {len(data_loader.dataset)} test images: {miou:.2f}%')
  98. else:
  99. logger.info('No segmentation evaluation specified')
  100. if cfg.vis:
  101. vis_seg(cfg, data_loader, model, cfg.vis)
@torch.no_grad()
def vis_seg(config, data_loader, model, vis_modes):
    """Render qualitative segmentation results to ``<output>/vis_imgs/<mode>/``.

    Wraps the model in an mmseg-style inference wrapper, runs it over the
    evaluation dataset, and writes one JPEG per sample per visualization mode.

    Args:
        config: full experiment config (OmegaConf).
        data_loader: segmentation evaluation dataloader.
        model: trained model, possibly wrapped in DistributedDataParallel.
        vis_modes: list of mode names forwarded to ``model.show_result``
            (e.g. 'input', 'pred', 'all_groups', ...).
    """
    dist.barrier()
    model.eval()
    # Unwrap DistributedDataParallel if present.
    if hasattr(model, 'module'):
        model_without_ddp = model.module
    else:
        model_without_ddp = model
    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
    # HuggingFace-backed text encoders need the module-global tokenizer (set
    # in inference()) to encode class-name prompts; TextTransformer does not.
    if config.model.text_encoder['type'] in ['DistilBert', 'Bert','BertMedium','Roberta']:
        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg, tokenizer)
    else:
        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)
    # Re-wrap for distributed inference; broadcast_buffers=False leaves
    # per-rank buffers untouched during eval.
    mmddp_model = MMDistributedDataParallel(
        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    mmddp_model.eval()
    model = mmddp_model.module
    device = next(model.parameters()).device
    dataset = data_loader.dataset
    # Progress bar only on rank 0 to avoid interleaved console output.
    if dist.get_rank() == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    loader_indices = data_loader.batch_sampler
    for batch_indices, data in zip(loader_indices, data_loader):
        with torch.no_grad():
            result = mmddp_model(return_loss=False, **data)
        img_tensor = data['img'][0]
        img_metas = data['img_metas'][0].data[0]
        # Undo test-time normalization so the images can be saved directly.
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)
        for batch_idx, img, img_meta in zip(batch_indices, imgs, img_metas):
            h, w, _ = img_meta['img_shape']
            # Crop away test-time padding, then resize back to the original
            # (pre-augmentation) resolution.
            img_show = img[:h, :w, :]
            ori_h, ori_w = img_meta['ori_shape'][:-1]
            img_show = mmcv.imresize(img_show, (ori_w, ori_h))
            for vis_mode in vis_modes:
                out_file = osp.join(config.output, 'vis_imgs', vis_mode, f'{batch_idx:04d}.jpg')
                model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)
        if dist.get_rank() == 0:
            # Advance by the global number of samples processed this step.
            batch_size = len(result) * dist.get_world_size()
            for _ in range(batch_size):
                prog_bar.update()
  143. def main():
  144. args = parse_args()
  145. cfg = get_config(args)
  146. if cfg.train.amp_opt_level != 'O0':
  147. assert amp is not None, 'amp not installed!'
  148. with read_write(cfg):
  149. cfg.evaluate.eval_only = True
  150. init_distributed_mode(args)
  151. rank, world_size = args.rank, args.world_size
  152. set_random_seed(cfg.seed, use_rank_shift=True)
  153. cudnn.benchmark = True
  154. os.makedirs(cfg.output, exist_ok=True)
  155. logger = get_logger(cfg)
  156. if dist.get_rank() == 0:
  157. path = os.path.join(cfg.output, 'config.json')
  158. OmegaConf.save(cfg, path)
  159. logger.info(f'Full config saved to {path}')
  160. # print config
  161. logger.info(OmegaConf.to_yaml(cfg))
  162. inference(cfg)
  163. dist.barrier()
# Script entry point.
if __name__ == '__main__':
    main()