main_seg.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. # ------------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # ------------------------------------------------------------------------------
  13. # Modified by Jilan Xu
  14. # -------------------------------------------------------------------------
import argparse
import os
import os.path as osp
import subprocess
import sys

import mmcv
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from ipdb import set_trace
from mmcv.image import tensor2imgs
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import set_random_seed
from omegaconf import OmegaConf, read_write
from transformers import AutoTokenizer, RobertaTokenizer

from datasets import build_text_transform
from main_group_vit import validate_seg
from models import build_model
from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
from utils import get_config, get_logger, load_checkpoint
# Optional NVIDIA apex for mixed precision; `amp` stays None when apex is
# absent (guarded again in main() before any O1/O2 run).
try:
# noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None

# Maps cfg.model.text_encoder.type -> ready-to-use tokenizer instance.
# NOTE(review): both tokenizers are built eagerly at import time — the 'Bert'
# entry may hit the HuggingFace hub and the 'Roberta' entry reads a
# hard-coded cluster path; confirm both are reachable on the target machine.
tokenizer_dict = {
    'Bert': AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False),
    'Roberta': RobertaTokenizer.from_pretrained('/mnt/petrelfs/xujilan/roberta-base/'),
    'TextTransformer': None,
}
  44. def parse_args():
  45. parser = argparse.ArgumentParser('GroupViT segmentation evaluation and visualization')
  46. parser.add_argument(
  47. '--cfg',
  48. type=str,
  49. required=True,
  50. help='path to config file',
  51. )
  52. parser.add_argument(
  53. '--opts',
  54. help="Modify config options by adding 'KEY VALUE' pairs. ",
  55. default=None,
  56. nargs='+',
  57. )
  58. parser.add_argument('--resume', help='resume from checkpoint')
  59. parser.add_argument(
  60. '--output', type=str, help='root of output folder, '
  61. 'the full path is <output>/<model_name>/<tag>')
  62. parser.add_argument('--tag', help='tag of experiment')
  63. parser.add_argument(
  64. '--vis',
  65. help='Specify the visualization mode, '
  66. 'could be a list, support input, pred, input_seg, input_pred_seg_label, all_groups, first_group, last_group',
  67. default=None,
  68. nargs='+')
  69. # distributed training
  70. parser.add_argument('--local_rank', type=int, required=False, default=0, help='local rank for DistributedDataParallel')
  71. args = parser.parse_args()
  72. return args
  73. def inference(cfg):
  74. logger = get_logger()
  75. data_loader = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))
  76. dataset = data_loader.dataset
  77. print('whether activating visualization: ', cfg.vis)
  78. logger.info(f'Evaluating dataset: {dataset}')
  79. logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
  80. model = build_model(cfg.model)
  81. model.cuda()
  82. logger.info(str(model))
  83. if cfg.train.amp_opt_level != 'O0':
  84. model = amp.initialize(model, None, opt_level=cfg.train.amp_opt_level)
  85. n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
  86. logger.info(f'number of params: {n_parameters}')
  87. load_checkpoint(cfg, model, None, None)
  88. global tokenizer
  89. tokenizer = tokenizer_dict[cfg.model.text_encoder.type]
  90. if cfg.model.text_encoder.type == 'Roberta':
  91. tokenizer = RobertaTokenizer.from_pretrained('/mnt/petrelfs/xujilan/roberta-base/')
  92. print('Done switching roberta tokenizer')
  93. if 'seg' in cfg.evaluate.task:
  94. miou = validate_seg(cfg, data_loader, model, tokenizer=tokenizer)
  95. logger.info(f'mIoU of the network on the {len(data_loader.dataset)} test images: {miou:.2f}%')
  96. else:
  97. logger.info('No segmentation evaluation specified')
  98. if cfg.vis:
  99. vis_seg(cfg, data_loader, model, cfg.vis)
@torch.no_grad()
def vis_seg(config, data_loader, model, vis_modes):
    """Run segmentation inference over the dataset and write visualization
    images to disk.

    Args:
        config: experiment config; ``config.output`` is the output root and
            ``config.evaluate.seg`` configures the inference wrapper.
        data_loader: segmentation dataloader yielding mmseg-style batches
            with 'img' and 'img_metas'.
        model: trained model, possibly DDP-wrapped.
        vis_modes: list of visualization-mode names; one JPEG per mode per
            sample is written under <output>/vis_imgs/<mode>/.
    """
    dist.barrier()
    model.eval()

    # Unwrap an existing DDP wrapper to get the bare model.
    if hasattr(model, 'module'):
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
    # HuggingFace text encoders need the module-global `tokenizer`, which
    # `inference` assigns before calling this function.
    if config.model.text_encoder['type'] in ['DistilBert', 'Bert','BertMedium','Roberta']:
        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg, tokenizer)
    else:
        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)

    # Re-wrap the inference model for multi-GPU evaluation.
    mmddp_model = MMDistributedDataParallel(
        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    mmddp_model.eval()
    model = mmddp_model.module
    device = next(model.parameters()).device
    dataset = data_loader.dataset

    # Progress bar only on rank 0; it advances by world-size per batch below.
    if dist.get_rank() == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    loader_indices = data_loader.batch_sampler
    for batch_indices, data in zip(loader_indices, data_loader):
        with torch.no_grad():
            result = mmddp_model(return_loss=False, **data)

        img_tensor = data['img'][0]
        img_metas = data['img_metas'][0].data[0]
        # De-normalize tensors back to displayable uint8 images.
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)

        for batch_idx, img, img_meta in zip(batch_indices, imgs, img_metas):
            # Crop away padding, then resize to the original image shape.
            h, w, _ = img_meta['img_shape']
            img_show = img[:h, :w, :]
            ori_h, ori_w = img_meta['ori_shape'][:-1]
            img_show = mmcv.imresize(img_show, (ori_w, ori_h))

            for vis_mode in vis_modes:
                out_file = osp.join(config.output, 'vis_imgs', vis_mode, f'{batch_idx:04d}.jpg')
                model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)

        if dist.get_rank() == 0:
            # Account for the samples processed by every rank this step.
            batch_size = len(result) * dist.get_world_size()
            for _ in range(batch_size):
                prog_bar.update()
  141. def setup_for_distributed(is_master):
  142. """
  143. This function disables printing when not in master process
  144. """
  145. import builtins as __builtin__
  146. builtin_print = __builtin__.print
  147. def print(*args, **kwargs):
  148. force = kwargs.pop('force', False)
  149. if is_master or force:
  150. builtin_print(*args, **kwargs)
  151. __builtin__.print = print
  152. def init_distributed_mode(args):
  153. # launched with torch.distributed.launch
  154. if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
  155. args.rank = int(os.environ["RANK"])
  156. args.world_size = int(os.environ['WORLD_SIZE'])
  157. args.gpu = int(os.environ['LOCAL_RANK'])
  158. # launched with submitit on a slurm cluster
  159. elif 'SLURM_PROCID' in os.environ:
  160. #args.rank = int(os.environ['SLURM_PROCID'])
  161. #args.gpu = args.rank % torch.cuda.device_count()
  162. proc_id = int(os.environ['SLURM_PROCID'])
  163. ntasks = os.environ['SLURM_NTASKS']
  164. node_list = os.environ['SLURM_NODELIST']
  165. num_gpus = torch.cuda.device_count()
  166. addr = subprocess.getoutput(
  167. 'scontrol show hostname {} | head -n1'.format(node_list)
  168. )
  169. master_port = os.environ.get('MASTER_PORT', '29499')
  170. os.environ['MASTER_PORT'] = master_port
  171. os.environ['MASTER_ADDR'] = addr
  172. os.environ['WORLD_SIZE'] = str(ntasks)
  173. os.environ['RANK'] = str(proc_id)
  174. os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
  175. os.environ['LOCAL_SIZE'] = str(num_gpus)
  176. args.dist_url = 'env://'
  177. args.world_size = int(ntasks)
  178. args.rank = int(proc_id)
  179. args.gpu = int(proc_id % num_gpus)
  180. print(f'SLURM MODE: proc_id: {proc_id}, ntasks: {ntasks}, node_list: {node_list}, num_gpus:{num_gpus}, addr:{addr}, master port:{master_port}' )
  181. # launched naively with `python main_dino.py`
  182. # we manually add MASTER_ADDR and MASTER_PORT to env variables
  183. elif torch.cuda.is_available():
  184. print('Will run the code on one GPU.')
  185. args.rank, args.gpu, args.world_size = 0, 0, 1
  186. os.environ['MASTER_ADDR'] = '127.0.0.1'
  187. os.environ['MASTER_PORT'] = '29500'
  188. else:
  189. print('Does not support training without GPU.')
  190. sys.exit(1)
  191. dist.init_process_group(
  192. backend="nccl",
  193. init_method=args.dist_url,
  194. world_size=args.world_size,
  195. rank=args.rank,
  196. )
  197. torch.cuda.set_device(args.gpu)
  198. print('| distributed init (rank {}): {}'.format(
  199. args.rank, args.dist_url), flush=True)
  200. dist.barrier()
  201. setup_for_distributed(args.rank == 0)
  202. def main():
  203. args = parse_args()
  204. cfg = get_config(args)
  205. if cfg.train.amp_opt_level != 'O0':
  206. assert amp is not None, 'amp not installed!'
  207. with read_write(cfg):
  208. cfg.evaluate.eval_only = True
  209. init_distributed_mode(args)
  210. rank, world_size = args.rank, args.world_size
  211. set_random_seed(cfg.seed, use_rank_shift=True)
  212. cudnn.benchmark = True
  213. os.makedirs(cfg.output, exist_ok=True)
  214. logger = get_logger(cfg)
  215. if dist.get_rank() == 0:
  216. path = os.path.join(cfg.output, 'config.json')
  217. OmegaConf.save(cfg, path)
  218. logger.info(f'Full config saved to {path}')
  219. # print config
  220. logger.info(OmegaConf.to_yaml(cfg))
  221. inference(cfg)
  222. dist.barrier()
# Script entry point.
if __name__ == '__main__':
    main()