123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- # -------------------------------------------------------------------------
- # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
- #
- # This work is made available under the Nvidia Source Code License.
- # To view a copy of this license, visit
- # https://github.com/NVlabs/GroupViT/blob/main/LICENSE
- #
- # Written by Jiarui Xu
- # -------------------------------------------------------------------------
- import mmcv
- from mmseg.datasets import build_dataloader, build_dataset
- from mmseg.datasets.pipelines import Compose
- from omegaconf import OmegaConf
- from utils import build_dataset_class_tokens
- from .group_vit_seg import GroupViTSegInference
def build_seg_dataset(config):
    """Create the test-split segmentation dataset described by ``config.cfg``.

    Args:
        config: Object exposing a ``cfg`` attribute, which is handed to
            ``mmcv.Config.fromfile`` to load the mmseg-style config.

    Returns:
        Whatever ``build_dataset`` produces for the loaded config's
        ``data.test`` section.
    """
    mm_cfg = mmcv.Config.fromfile(config.cfg)
    return build_dataset(mm_cfg.data.test)
def build_seg_dataloader(dataset):
    """Wrap ``dataset`` in a data loader using fixed evaluation settings.

    The loader is always requested with one sample and one worker per GPU,
    distributed mode on, no shuffling, persistent workers, and no pinned
    memory.

    Args:
        dataset: Dataset object forwarded as-is to ``build_dataloader``.

    Returns:
        The data loader returned by ``build_dataloader``.
    """
    loader_options = {
        'samples_per_gpu': 1,
        'workers_per_gpu': 1,
        'dist': True,
        'shuffle': False,
        'persistent_workers': True,
        'pin_memory': False,
    }
    return build_dataloader(dataset, **loader_options)
def build_seg_inference(model, dataset, text_transform, config):
    """Assemble a ``GroupViTSegInference`` wrapper for ``model``.

    Args:
        model: Object providing ``build_text_embedding``; also passed
            through to ``GroupViTSegInference``.
        dataset: Object exposing ``CLASSES`` and ``PALETTE``.
        text_transform: Forwarded to ``build_dataset_class_tokens``.
        config: Object exposing ``cfg`` (config file path), ``opts``
            (dot-list overrides) and ``template``.

    Returns:
        The ``GroupViTSegInference`` instance, with ``CLASSES`` and
        ``PALETTE`` copied from ``dataset``.
    """
    cfg = mmcv.Config.fromfile(config.cfg)
    if len(config.opts) > 0:
        # Convert the dot-list overrides into a plain container and merge
        # them on top of the file-based config.
        dotlist = OmegaConf.to_container(config.opts)
        overrides = OmegaConf.from_dotlist(dotlist)
        cfg.merge_from_dict(OmegaConf.to_container(overrides))

    # A leading 'background' entry is excluded from the class names that
    # are turned into text tokens.
    with_bg = dataset.CLASSES[0] == 'background'
    classnames = dataset.CLASSES[1:] if with_bg else dataset.CLASSES

    text_tokens = build_dataset_class_tokens(text_transform, config.template, classnames)
    text_embedding = model.build_text_embedding(text_tokens)

    # ``test_cfg`` is forwarded only when the loaded config defines it.
    extra_kwargs = {'with_bg': with_bg}
    if hasattr(cfg, 'test_cfg'):
        extra_kwargs['test_cfg'] = cfg.test_cfg

    seg_model = GroupViTSegInference(model, text_embedding, **extra_kwargs)
    seg_model.CLASSES = dataset.CLASSES
    seg_model.PALETTE = dataset.PALETTE
    return seg_model
class LoadImage:
    """Pipeline step that reads ``results['img']`` via ``mmcv.imread``."""

    def __call__(self, results):
        """Load the image and record its filename and shape in ``results``.

        Args:
            results (dict): Must contain an ``'img'`` entry. When that
                entry is a ``str`` it is also recorded as the filename;
                otherwise the filename fields are set to ``None``.

        Returns:
            dict: The same ``results`` dict, with ``'filename'``,
            ``'ori_filename'``, ``'img'``, ``'img_shape'`` and
            ``'ori_shape'`` populated.
        """
        source = results['img']
        # Only string inputs carry a meaningful filename.
        name = source if isinstance(source, str) else None
        results['filename'] = name
        results['ori_filename'] = name

        loaded = mmcv.imread(source)
        results['img'] = loaded
        results['img_shape'] = loaded.shape
        results['ori_shape'] = loaded.shape
        return results
def build_seg_demo_pipeline():
    """Build the fixed demo pipeline used to preprocess a single image.

    The pipeline is a ``Compose`` of ``LoadImage`` followed by a
    ``MultiScaleFlipAug`` step (``img_scale=(2048, 448)``, ``flip=False``)
    whose inner transforms are Resize, RandomFlip, Normalize,
    ImageToTensor and Collect. All values are hard-coded here.

    Returns:
        The composed pipeline object.
    """
    img_norm_cfg = {
        'mean': [123.675, 116.28, 103.53],
        'std': [58.395, 57.12, 57.375],
        'to_rgb': True,
    }
    inner_transforms = [
        {'type': 'Resize', 'keep_ratio': True},
        {'type': 'RandomFlip'},
        {'type': 'Normalize', **img_norm_cfg},
        {'type': 'ImageToTensor', 'keys': ['img']},
        {'type': 'Collect', 'keys': ['img']},
    ]
    multi_scale_step = {
        'type': 'MultiScaleFlipAug',
        'img_scale': (2048, 448),
        'flip': False,
        'transforms': inner_transforms,
    }
    return Compose([LoadImage(), multi_scale_step])
|