builder.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
  3. #
  4. # This work is made available under the Nvidia Source Code License.
  5. # To view a copy of this license, visit
  6. # https://github.com/NVlabs/GroupViT/blob/main/LICENSE
  7. #
  8. # Written by Jiarui Xu
  9. # -------------------------------------------------------------------------
  10. import mmcv
  11. from mmseg.datasets import build_dataloader, build_dataset
  12. from mmseg.datasets.pipelines import Compose
  13. from omegaconf import OmegaConf
  14. from utils import build_dataset_class_tokens
  15. from .group_vit_seg import GroupViTSegInference
  16. def build_seg_dataset(config):
  17. """Build a dataset from config."""
  18. cfg = mmcv.Config.fromfile(config.cfg)
  19. dataset = build_dataset(cfg.data.test)
  20. return dataset
  21. def build_seg_dataloader(dataset):
  22. data_loader = build_dataloader(
  23. dataset,
  24. samples_per_gpu=1,
  25. workers_per_gpu=1,
  26. dist=True,
  27. shuffle=False,
  28. persistent_workers=True,
  29. pin_memory=False)
  30. return data_loader
  31. def build_seg_inference(model, dataset, text_transform, config):
  32. cfg = mmcv.Config.fromfile(config.cfg)
  33. if len(config.opts):
  34. cfg.merge_from_dict(OmegaConf.to_container(OmegaConf.from_dotlist(OmegaConf.to_container(config.opts))))
  35. with_bg = dataset.CLASSES[0] == 'background'
  36. if with_bg:
  37. classnames = dataset.CLASSES[1:]
  38. else:
  39. classnames = dataset.CLASSES
  40. text_tokens = build_dataset_class_tokens(text_transform, config.template, classnames)
  41. text_embedding = model.build_text_embedding(text_tokens)
  42. kwargs = dict(with_bg=with_bg)
  43. if hasattr(cfg, 'test_cfg'):
  44. kwargs['test_cfg'] = cfg.test_cfg
  45. seg_model = GroupViTSegInference(model, text_embedding, **kwargs)
  46. seg_model.CLASSES = dataset.CLASSES
  47. seg_model.PALETTE = dataset.PALETTE
  48. return seg_model
  49. class LoadImage:
  50. """A simple pipeline to load image."""
  51. def __call__(self, results):
  52. """Call function to load images into results.
  53. Args:
  54. results (dict): A result dict contains the file name
  55. of the image to be read.
  56. Returns:
  57. dict: ``results`` will be returned containing loaded image.
  58. """
  59. if isinstance(results['img'], str):
  60. results['filename'] = results['img']
  61. results['ori_filename'] = results['img']
  62. else:
  63. results['filename'] = None
  64. results['ori_filename'] = None
  65. img = mmcv.imread(results['img'])
  66. results['img'] = img
  67. results['img_shape'] = img.shape
  68. results['ori_shape'] = img.shape
  69. return results
  70. def build_seg_demo_pipeline():
  71. """Build a demo pipeline from config."""
  72. img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
  73. test_pipeline = Compose([
  74. LoadImage(),
  75. dict(
  76. type='MultiScaleFlipAug',
  77. img_scale=(2048, 448),
  78. flip=False,
  79. transforms=[
  80. dict(type='Resize', keep_ratio=True),
  81. dict(type='RandomFlip'),
  82. dict(type='Normalize', **img_norm_cfg),
  83. dict(type='ImageToTensor', keys=['img']),
  84. dict(type='Collect', keys=['img']),
  85. ])
  86. ])
  87. return test_pipeline