builder.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # -------------------------------------------------------------------------
  13. # Modified by Jilan Xu
  14. # -------------------------------------------------------------------------
  15. import mmcv
  16. from mmseg.datasets import build_dataloader, build_dataset
  17. from mmseg.datasets.pipelines import Compose
  18. from omegaconf import OmegaConf
  19. from utils import build_dataset_class_tokens, build_dataset_class_lists
  20. from .group_vit_seg import GroupViTSegInference
  21. from ipdb import set_trace
  22. def build_seg_dataset(config):
  23. """Build a dataset from config."""
  24. cfg = mmcv.Config.fromfile(config.cfg)
  25. dataset = build_dataset(cfg.data.test)
  26. return dataset
  27. def build_seg_dataloader(dataset):
  28. data_loader = build_dataloader(
  29. dataset,
  30. samples_per_gpu=1,
  31. workers_per_gpu=1,
  32. dist=True,
  33. shuffle=False,
  34. persistent_workers=True,
  35. pin_memory=False)
  36. return data_loader
  37. def build_seg_inference(model, dataset, text_transform, config, tokenizer=None):
  38. cfg = mmcv.Config.fromfile(config.cfg)
  39. if len(config.opts):
  40. cfg.merge_from_dict(OmegaConf.to_container(OmegaConf.from_dotlist(OmegaConf.to_container(config.opts))))
  41. with_bg = dataset.CLASSES[0] == 'background'
  42. if with_bg:
  43. classnames = dataset.CLASSES[1:]
  44. else:
  45. classnames = dataset.CLASSES
  46. if tokenizer is not None:
  47. text_tokens = build_dataset_class_lists(config.template, classnames)
  48. text_embedding = model.build_text_embedding(text_tokens, tokenizer, num_classes=len(classnames))
  49. else:
  50. text_tokens = build_dataset_class_tokens(text_transform, config.template, classnames)
  51. text_embedding = model.build_text_embedding(text_tokens, num_classes=len(classnames))
  52. kwargs = dict(with_bg=with_bg)
  53. if hasattr(cfg, 'test_cfg'):
  54. kwargs['test_cfg'] = cfg.test_cfg
  55. seg_model = GroupViTSegInference(model, text_embedding, **kwargs)
  56. print('Evaluate GroupViT during seg inference')
  57. seg_model.CLASSES = dataset.CLASSES
  58. seg_model.PALETTE = dataset.PALETTE
  59. return seg_model
  60. class LoadImage:
  61. """A simple pipeline to load image."""
  62. cnt = 0
  63. def __call__(self, results):
  64. """Call function to load images into results.
  65. Args:
  66. results (dict): A result dict contains the file name
  67. of the image to be read.
  68. Returns:
  69. dict: ``results`` will be returned containing loaded image.
  70. """
  71. if isinstance(results['img'], str):
  72. results['filename'] = results['img']
  73. results['ori_filename'] = results['img']
  74. else:
  75. results['filename'] = None
  76. results['ori_filename'] = None
  77. img = mmcv.imread(results['img'])
  78. results['img'] = img
  79. results['img_shape'] = img.shape
  80. results['ori_shape'] = img.shape
  81. return results
  82. def build_seg_demo_pipeline():
  83. """Build a demo pipeline from config."""
  84. img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
  85. test_pipeline = Compose([
  86. LoadImage(),
  87. dict(
  88. type='MultiScaleFlipAug',
  89. img_scale=(2048, 448),
  90. flip=False,
  91. transforms=[
  92. dict(type='Resize', keep_ratio=True),
  93. dict(type='RandomFlip'),
  94. dict(type='Normalize', **img_norm_cfg),
  95. dict(type='ImageToTensor', keys=['img']),
  96. dict(type='Collect', keys=['img']),
  97. ])
  98. ])
  99. return test_pipeline