builder.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # -------------------------------------------------------------------------
  13. # Modified by Jilan Xu
  14. # -------------------------------------------------------------------------
  15. import mmcv
  16. from mmseg.datasets import build_dataloader, build_dataset
  17. from mmseg.datasets.pipelines import Compose
  18. from omegaconf import OmegaConf
  19. from utils import build_dataset_class_tokens, build_dataset_class_lists
  20. from .group_vit_seg import GroupViTSegInference
  21. from ipdb import set_trace
  22. def build_seg_dataset(config):
  23. """Build a dataset from config."""
  24. cfg = mmcv.Config.fromfile(config.cfg)
  25. dataset = build_dataset(cfg.data.test)
  26. return dataset
  27. def build_custom_seg_dataset(config, args):
  28. """Build a dataset from config."""
  29. cfg = mmcv.Config.fromfile(config.cfg)
  30. cfg.data.test.data_root = args.image_folder
  31. cfg.data.test.img_dir = ''
  32. cfg.data.test.ann_dir = '' ## unsure
  33. cfg.data.test.split = 'image_list.txt'
  34. dataset = build_dataset(cfg.data.test)
  35. return dataset
  36. def build_seg_dataloader(dataset):
  37. data_loader = build_dataloader(
  38. dataset,
  39. samples_per_gpu=1,
  40. workers_per_gpu=1,
  41. dist=True,
  42. shuffle=False,
  43. persistent_workers=True,
  44. pin_memory=False)
  45. return data_loader
  46. def build_seg_inference(model, dataset, text_transform, config, tokenizer=None):
  47. cfg = mmcv.Config.fromfile(config.cfg)
  48. if len(config.opts):
  49. cfg.merge_from_dict(OmegaConf.to_container(OmegaConf.from_dotlist(OmegaConf.to_container(config.opts))))
  50. with_bg = dataset.CLASSES[0] == 'background'
  51. if with_bg:
  52. classnames = dataset.CLASSES[1:]
  53. else:
  54. classnames = dataset.CLASSES
  55. if tokenizer is not None:
  56. text_tokens = build_dataset_class_lists(config.template, classnames)
  57. text_embedding = model.build_text_embedding(text_tokens, tokenizer, num_classes=len(classnames))
  58. else:
  59. text_tokens = build_dataset_class_tokens(text_transform, config.template, classnames)
  60. text_embedding = model.build_text_embedding(text_tokens, num_classes=len(classnames))
  61. kwargs = dict(with_bg=with_bg)
  62. if hasattr(cfg, 'test_cfg'):
  63. kwargs['test_cfg'] = cfg.test_cfg
  64. seg_model = GroupViTSegInference(model, text_embedding, **kwargs)
  65. print('Evaluate during seg inference')
  66. seg_model.CLASSES = dataset.CLASSES
  67. seg_model.PALETTE = dataset.PALETTE
  68. return seg_model
  69. def build_demo_inference(model, text_transform, config, tokenizer=None):
  70. seg_config = config.evaluate.seg
  71. cfg = mmcv.Config.fromfile(seg_config.cfg)
  72. if len(seg_config.opts):
  73. cfg.merge_from_dict(OmegaConf.to_container(OmegaConf.from_dotlist(OmegaConf.to_container(seg_config.opts))))
  74. with_bg = True
  75. from segmentation.datasets.ade20k import ADE20KDataset
  76. from segmentation.datasets.coco_object import COCOObjectDataset
  77. from segmentation.datasets.pascal_voc import PascalVOCDataset
  78. if config.vocab == ['voc']:
  79. classnames = PascalVOCDataset.CLASSES
  80. palette = PascalVOCDataset.PALETTE
  81. elif config.vocab == ['coco']:
  82. classnames = COCOObjectDataset.CLASSES
  83. palette = COCOObjectDataset.PALETTE
  84. elif config.vocab == ['ade']:
  85. classnames = ADE20KDataset.CLASSES
  86. palette = ADE20KDataset.PALETTE
  87. else:
  88. classnames = config.vocab
  89. palette = ADE20KDataset.PALETTE[:len(classnames)]
  90. if classnames[0] == 'background':
  91. classnames = classnames[1:]
  92. print('candidate CLASSES: ', classnames)
  93. print('Using palette: ', palette)
  94. if tokenizer is not None:
  95. text_tokens = build_dataset_class_lists(seg_config.template, classnames)
  96. text_embedding = model.build_text_embedding(text_tokens, tokenizer, num_classes=len(classnames))
  97. else:
  98. text_tokens = build_dataset_class_tokens(text_transform, seg_config.template, classnames)
  99. text_embedding = model.build_text_embedding(text_tokens, num_classes=len(classnames))
  100. kwargs = dict(with_bg=with_bg)
  101. if hasattr(cfg, 'test_cfg'):
  102. kwargs['test_cfg'] = cfg.test_cfg
  103. seg_model = GroupViTSegInference(model, text_embedding, **kwargs)
  104. print('Evaluate during seg inference')
  105. seg_model.CLASSES = tuple(['background'] + list(classnames))
  106. seg_model.PALETTE = palette
  107. return seg_model
  108. class LoadImage:
  109. """A simple pipeline to load image."""
  110. cnt = 0
  111. def __call__(self, results):
  112. """Call function to load images into results.
  113. Args:
  114. results (dict): A result dict contains the file name
  115. of the image to be read.
  116. Returns:
  117. dict: ``results`` will be returned containing loaded image.
  118. """
  119. if isinstance(results['img'], str):
  120. results['filename'] = results['img']
  121. results['ori_filename'] = results['img']
  122. else:
  123. results['filename'] = None
  124. results['ori_filename'] = None
  125. img = mmcv.imread(results['img'])
  126. results['img'] = img
  127. results['img_shape'] = img.shape
  128. results['ori_shape'] = img.shape
  129. return results
  130. def build_seg_demo_pipeline():
  131. """Build a demo pipeline from config."""
  132. img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
  133. test_pipeline = Compose([
  134. LoadImage(),
  135. dict(
  136. type='MultiScaleFlipAug',
  137. img_scale=(2048, 448),
  138. flip=False,
  139. transforms=[
  140. dict(type='Resize', keep_ratio=True),
  141. dict(type='RandomFlip'),
  142. dict(type='Normalize', **img_norm_cfg),
  143. dict(type='ImageToTensor', keys=['img']),
  144. dict(type='Collect', keys=['img']),
  145. ])
  146. ])
  147. return test_pipeline