builder.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # -------------------------------------------------------------------------
  13. import mmcv
  14. from mmseg.datasets import build_dataloader, build_dataset
  15. from mmseg.datasets.pipelines import Compose
  16. from omegaconf import OmegaConf
  17. from utils import build_dataset_class_tokens
  18. from .group_vit_seg import GroupViTSegInference
  19. def build_seg_dataset(config):
  20. """Build a dataset from config."""
  21. cfg = mmcv.Config.fromfile(config.cfg)
  22. dataset = build_dataset(cfg.data.test)
  23. return dataset
  24. def build_seg_dataloader(dataset):
  25. data_loader = build_dataloader(
  26. dataset,
  27. samples_per_gpu=1,
  28. workers_per_gpu=1,
  29. dist=True,
  30. shuffle=False,
  31. persistent_workers=True,
  32. pin_memory=False)
  33. return data_loader
  34. def build_seg_inference(model, dataset, text_transform, config):
  35. cfg = mmcv.Config.fromfile(config.cfg)
  36. if len(config.opts):
  37. cfg.merge_from_dict(OmegaConf.to_container(OmegaConf.from_dotlist(OmegaConf.to_container(config.opts))))
  38. with_bg = dataset.CLASSES[0] == 'background'
  39. if with_bg:
  40. classnames = dataset.CLASSES[1:]
  41. else:
  42. classnames = dataset.CLASSES
  43. text_tokens = build_dataset_class_tokens(text_transform, config.template, classnames)
  44. text_embedding = model.build_text_embedding(text_tokens)
  45. kwargs = dict(with_bg=with_bg)
  46. if hasattr(cfg, 'test_cfg'):
  47. kwargs['test_cfg'] = cfg.test_cfg
  48. seg_model = GroupViTSegInference(model, text_embedding, **kwargs)
  49. seg_model.CLASSES = dataset.CLASSES
  50. seg_model.PALETTE = dataset.PALETTE
  51. return seg_model
  52. class LoadImage:
  53. """A simple pipeline to load image."""
  54. def __call__(self, results):
  55. """Call function to load images into results.
  56. Args:
  57. results (dict): A result dict contains the file name
  58. of the image to be read.
  59. Returns:
  60. dict: ``results`` will be returned containing loaded image.
  61. """
  62. if isinstance(results['img'], str):
  63. results['filename'] = results['img']
  64. results['ori_filename'] = results['img']
  65. else:
  66. results['filename'] = None
  67. results['ori_filename'] = None
  68. img = mmcv.imread(results['img'])
  69. results['img'] = img
  70. results['img_shape'] = img.shape
  71. results['ori_shape'] = img.shape
  72. return results
  73. def build_seg_demo_pipeline():
  74. """Build a demo pipeline from config."""
  75. img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
  76. test_pipeline = Compose([
  77. LoadImage(),
  78. dict(
  79. type='MultiScaleFlipAug',
  80. img_scale=(2048, 448),
  81. flip=False,
  82. transforms=[
  83. dict(type='Resize', keep_ratio=True),
  84. dict(type='RandomFlip'),
  85. dict(type='Normalize', **img_norm_cfg),
  86. dict(type='ImageToTensor', keys=['img']),
  87. dict(type='Collect', keys=['img']),
  88. ])
  89. ])
  90. return test_pipeline