# main_seg.py
  1. # ------------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # ------------------------------------------------------------------------------
  13. # Modified by Jilan Xu
  14. # -------------------------------------------------------------------------
import argparse
import os
import os.path as osp
import subprocess
import sys

import mmcv
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from ipdb import set_trace
from mmcv.image import tensor2imgs
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import set_random_seed
from omegaconf import OmegaConf, read_write
from transformers import AutoTokenizer, RobertaTokenizer

from datasets import build_text_transform
from main_pretrain import validate_seg
from models import build_model
from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
from utils import get_config, get_logger, load_checkpoint
  34. try:
  35. # noinspection PyUnresolvedReferences
  36. from apex import amp
  37. except ImportError:
  38. amp = None
# Maps a config ``model.text_encoder.type`` string to a ready-made tokenizer.
# NOTE: these ``from_pretrained`` calls run at import time, so importing this
# module needs network access or a local HuggingFace cache.
tokenizer_dict = {
    # NOTE(review): TOKENIZERS_PARALLELISM is normally an environment
    # variable, not a ``from_pretrained`` keyword — it is likely ignored
    # here; confirm the original intent.
    'Bert': AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False),
    # 'Roberta': RobertaTokenizer.from_pretrained('/mnt/petrelfs/xujilan/roberta-base/'),
    'Roberta': RobertaTokenizer.from_pretrained('roberta-base'),
    # TextTransformer variants use the project's own text pipeline,
    # so no HuggingFace tokenizer is needed.
    'TextTransformer': None,
}
  45. def parse_args():
  46. parser = argparse.ArgumentParser('GroupViT segmentation evaluation and visualization')
  47. parser.add_argument(
  48. '--cfg',
  49. type=str,
  50. required=True,
  51. help='path to config file',
  52. )
  53. parser.add_argument(
  54. '--opts',
  55. help="Modify config options by adding 'KEY VALUE' pairs. ",
  56. default=None,
  57. nargs='+',
  58. )
  59. parser.add_argument('--resume', help='resume from checkpoint')
  60. parser.add_argument(
  61. '--output', type=str, help='root of output folder, '
  62. 'the full path is <output>/<model_name>/<tag>')
  63. parser.add_argument('--tag', help='tag of experiment')
  64. parser.add_argument(
  65. '--vis',
  66. help='Specify the visualization mode, '
  67. 'could be a list, support input, pred, input_seg, input_pred_seg_label, all_groups, first_group, last_group',
  68. default=None,
  69. nargs='+')
  70. # distributed training
  71. parser.add_argument('--local_rank', type=int, required=False, default=0, help='local rank for DistributedDataParallel')
  72. args = parser.parse_args()
  73. return args
def inference(cfg):
    """Build the model, load its checkpoint and run segmentation evaluation.

    Evaluates on the dataset described by ``cfg.evaluate.seg`` and, when
    ``cfg.vis`` is truthy, also renders visualizations via ``vis_seg``.
    Mutates the module-level ``tokenizer`` global as a side effect.
    """
    logger = get_logger()
    data_loader = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))
    dataset = data_loader.dataset
    print('whether activating visualization: ', cfg.vis)
    logger.info(f'Evaluating dataset: {dataset}')
    logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
    model = build_model(cfg.model)
    model.cuda()
    logger.info(str(model))
    # Wrap with apex AMP only when mixed precision is requested.
    if cfg.train.amp_opt_level != 'O0':
        model = amp.initialize(model, None, opt_level=cfg.train.amp_opt_level)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f'number of params: {n_parameters}')
    load_checkpoint(cfg, model, None, None)
    # Publish the tokenizer for this text encoder so vis_seg can reach it.
    global tokenizer
    tokenizer = tokenizer_dict[cfg.model.text_encoder.type]
    if cfg.model.text_encoder.type == 'Roberta':
        # NOTE(review): overrides the hub tokenizer with a cluster-local
        # path that may not exist outside the original environment — confirm.
        tokenizer = RobertaTokenizer.from_pretrained('/mnt/petrelfs/xujilan/roberta-base/')
        print('Done switching roberta tokenizer')
    if 'seg' in cfg.evaluate.task:
        miou = validate_seg(cfg, data_loader, model, tokenizer=tokenizer)
        logger.info(f'mIoU of the network on the {len(data_loader.dataset)} test images: {miou:.2f}%')
    else:
        logger.info('No segmentation evaluation specified')
    if cfg.vis:
        vis_seg(cfg, data_loader, model, cfg.vis)
@torch.no_grad()
def vis_seg(config, data_loader, model, vis_modes):
    """Render segmentation visualizations for every image in the loader.

    Wraps the model in an mmcv ``MMDistributedDataParallel`` inference
    module and writes one image per ``vis_mode`` under
    ``<config.output>/vis_imgs/<vis_mode>/``.
    """
    dist.barrier()
    model.eval()
    # Unwrap an existing DDP container, if any, to reach the raw model.
    if hasattr(model, 'module'):
        model_without_ddp = model.module
    else:
        model_without_ddp = model
    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
    # HuggingFace-based text encoders need the module-level tokenizer;
    # the custom TextTransformer path does not.
    if config.model.text_encoder['type'] in ['DistilBert', 'Bert','BertMedium','Roberta']:
        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg, tokenizer)
    else:
        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)
    mmddp_model = MMDistributedDataParallel(
        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    mmddp_model.eval()
    model = mmddp_model.module
    device = next(model.parameters()).device
    dataset = data_loader.dataset
    # Progress bar only on rank 0; other ranks run silently.
    if dist.get_rank() == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    loader_indices = data_loader.batch_sampler
    for batch_indices, data in zip(loader_indices, data_loader):
        with torch.no_grad():
            result = mmddp_model(return_loss=False, **data)
        img_tensor = data['img'][0]
        img_metas = data['img_metas'][0].data[0]
        # De-normalize tensors back into displayable uint8 images.
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)
        for batch_idx, img, img_meta in zip(batch_indices, imgs, img_metas):
            # Crop padding away, then resize back to the original resolution.
            h, w, _ = img_meta['img_shape']
            img_show = img[:h, :w, :]
            ori_h, ori_w = img_meta['ori_shape'][:-1]
            img_show = mmcv.imresize(img_show, (ori_w, ori_h))
            for vis_mode in vis_modes:
                out_file = osp.join(config.output, 'vis_imgs', vis_mode, f'{batch_idx:04d}.jpg')
                model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)
        if dist.get_rank() == 0:
            # Advance by the global batch size so the bar tracks all ranks.
            batch_size = len(result) * dist.get_world_size()
            for _ in range(batch_size):
                prog_bar.update()
  142. def setup_for_distributed(is_master):
  143. """
  144. This function disables printing when not in master process
  145. """
  146. import builtins as __builtin__
  147. builtin_print = __builtin__.print
  148. def print(*args, **kwargs):
  149. force = kwargs.pop('force', False)
  150. if is_master or force:
  151. builtin_print(*args, **kwargs)
  152. __builtin__.print = print
  153. def init_distributed_mode(args):
  154. # launched with torch.distributed.launch
  155. if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
  156. args.rank = int(os.environ["RANK"])
  157. args.world_size = int(os.environ['WORLD_SIZE'])
  158. args.gpu = int(os.environ['LOCAL_RANK'])
  159. # launched with submitit on a slurm cluster
  160. elif 'SLURM_PROCID' in os.environ:
  161. #args.rank = int(os.environ['SLURM_PROCID'])
  162. #args.gpu = args.rank % torch.cuda.device_count()
  163. proc_id = int(os.environ['SLURM_PROCID'])
  164. ntasks = os.environ['SLURM_NTASKS']
  165. node_list = os.environ['SLURM_NODELIST']
  166. num_gpus = torch.cuda.device_count()
  167. addr = subprocess.getoutput(
  168. 'scontrol show hostname {} | head -n1'.format(node_list)
  169. )
  170. master_port = os.environ.get('MASTER_PORT', '29499')
  171. os.environ['MASTER_PORT'] = master_port
  172. os.environ['MASTER_ADDR'] = addr
  173. os.environ['WORLD_SIZE'] = str(ntasks)
  174. os.environ['RANK'] = str(proc_id)
  175. os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
  176. os.environ['LOCAL_SIZE'] = str(num_gpus)
  177. args.dist_url = 'env://'
  178. args.world_size = int(ntasks)
  179. args.rank = int(proc_id)
  180. args.gpu = int(proc_id % num_gpus)
  181. print(f'SLURM MODE: proc_id: {proc_id}, ntasks: {ntasks}, node_list: {node_list}, num_gpus:{num_gpus}, addr:{addr}, master port:{master_port}' )
  182. # launched naively with `python main_dino.py`
  183. # we manually add MASTER_ADDR and MASTER_PORT to env variables
  184. elif torch.cuda.is_available():
  185. print('Will run the code on one GPU.')
  186. args.rank, args.gpu, args.world_size = 0, 0, 1
  187. os.environ['MASTER_ADDR'] = '127.0.0.1'
  188. os.environ['MASTER_PORT'] = '29500'
  189. else:
  190. print('Does not support training without GPU.')
  191. sys.exit(1)
  192. dist.init_process_group(
  193. backend="nccl",
  194. init_method=args.dist_url,
  195. world_size=args.world_size,
  196. rank=args.rank,
  197. )
  198. torch.cuda.set_device(args.gpu)
  199. print('| distributed init (rank {}): {}'.format(
  200. args.rank, args.dist_url), flush=True)
  201. dist.barrier()
  202. setup_for_distributed(args.rank == 0)
def main():
    """Entry point: parse args, set up distributed state, run inference."""
    args = parse_args()
    cfg = get_config(args)
    # Mixed precision requires apex; fail fast if it is not installed.
    if cfg.train.amp_opt_level != 'O0':
        assert amp is not None, 'amp not installed!'
    # Config is read-only by default; temporarily unlock it to flag eval mode.
    with read_write(cfg):
        cfg.evaluate.eval_only = True
    init_distributed_mode(args)
    rank, world_size = args.rank, args.world_size
    # Rank-shifted seeding keeps ranks deterministic but decorrelated.
    set_random_seed(cfg.seed, use_rank_shift=True)
    cudnn.benchmark = True
    os.makedirs(cfg.output, exist_ok=True)
    logger = get_logger(cfg)
    # Only rank 0 persists the resolved config to disk.
    if dist.get_rank() == 0:
        path = os.path.join(cfg.output, 'config.json')
        OmegaConf.save(cfg, path)
        logger.info(f'Full config saved to {path}')
    # print config
    logger.info(OmegaConf.to_yaml(cfg))
    inference(cfg)
    dist.barrier()
# Script entry point.
if __name__ == '__main__':
    main()