# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------

import mmcv
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.datasets.pipelines import Compose
from omegaconf import OmegaConf
from utils import build_dataset_class_tokens

from .group_vit_seg import GroupViTSegInference


def build_seg_dataset(config):
    """Build the test-split segmentation dataset described by ``config.cfg``.

    Args:
        config: Object whose ``cfg`` attribute is the path of an mmseg config
            file; the dataset is built from that file's ``data.test`` entry.

    Returns:
        The dataset produced by mmseg's ``build_dataset``.
    """
    cfg = mmcv.Config.fromfile(config.cfg)
    dataset = build_dataset(cfg.data.test)
    return dataset


def build_seg_dataloader(dataset, *, samples_per_gpu=1, workers_per_gpu=1,
                         dist=True):
    """Build an unshuffled evaluation dataloader for ``dataset``.

    Calling it with only ``dataset`` gives a batch size of 1 per GPU, one
    worker per GPU and ``dist=True``. The keyword arguments let callers change
    these, e.g. ``dist=False`` for a single-process run.

    Args:
        dataset: Dataset to iterate over.
        samples_per_gpu (int): Batch size per GPU.
        workers_per_gpu (int): Number of data-loading workers per GPU.
        dist (bool): Forwarded to mmseg's ``build_dataloader`` (distributed
            sampling or not).

    Returns:
        The dataloader produced by mmseg's ``build_dataloader``.
    """
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=workers_per_gpu,
        dist=dist,
        shuffle=False,
        # PyTorch's DataLoader rejects persistent_workers=True when there are
        # no worker processes, so only keep workers alive if some exist.
        persistent_workers=workers_per_gpu > 0,
        pin_memory=False)
    return data_loader


def build_seg_inference(model, dataset, text_transform, config):
    """Wrap ``model`` in a ``GroupViTSegInference`` for ``dataset``'s classes.

    Text embeddings are built for the dataset's class names and passed to the
    wrapper. If the first class is ``'background'``, it is left out of the
    class names and signalled with ``with_bg=True`` instead.

    Args:
        model: Model exposing ``build_text_embedding(text_tokens)``.
        dataset: Dataset providing ``CLASSES`` and ``PALETTE``.
        text_transform: Passed to ``build_dataset_class_tokens`` together with
            ``config.template`` and the class names.
        config: Object with ``cfg`` (path of an mmseg config file), ``opts``
            (dot-list overrides, e.g. ``['key=value']``, merged into that
            config) and ``template`` (forwarded to
            ``build_dataset_class_tokens``).

    Returns:
        GroupViTSegInference: The wrapper, with ``CLASSES`` and ``PALETTE``
        copied from ``dataset``.
    """
    cfg = mmcv.Config.fromfile(config.cfg)
    # Convert the dot-list overrides into a plain nested dict, which is what
    # mmcv's ``merge_from_dict`` expects. The truthiness check also tolerates
    # ``opts`` being None, where ``len()`` would raise.
    if config.opts:
        overrides = OmegaConf.from_dotlist(OmegaConf.to_container(config.opts))
        cfg.merge_from_dict(OmegaConf.to_container(overrides))

    with_bg = dataset.CLASSES[0] == 'background'
    # The background class gets no text prompt; its presence is passed to the
    # wrapper through ``with_bg`` instead.
    classnames = dataset.CLASSES[1:] if with_bg else dataset.CLASSES

    text_tokens = build_dataset_class_tokens(text_transform, config.template,
                                             classnames)
    text_embedding = model.build_text_embedding(text_tokens)

    kwargs = dict(with_bg=with_bg)
    # Only forward test_cfg when the mmseg config defines one.
    if hasattr(cfg, 'test_cfg'):
        kwargs['test_cfg'] = cfg.test_cfg

    seg_model = GroupViTSegInference(model, text_embedding, **kwargs)
    seg_model.CLASSES = dataset.CLASSES
    seg_model.PALETTE = dataset.PALETTE
    return seg_model


class LoadImage:
    """Pipeline step that loads an image into the results dict."""

    def __call__(self, results):
        """Load ``results['img']`` and record its filename and shape.

        Args:
            results (dict): Must contain ``'img'``, either a file path or any
                other input accepted by ``mmcv.imread`` (e.g. an
                already-decoded array).

        Returns:
            dict: The same ``results`` dict, updated in place. ``'img'`` now
            holds the loaded image. ``'filename'`` and ``'ori_filename'`` hold
            the path, or None when ``'img'`` was not a string.
            ``'img_shape'`` and ``'ori_shape'`` hold the image shape.
        """
        src = results['img']
        # Only a path has a meaningful filename; in-memory inputs get None.
        filename = src if isinstance(src, str) else None
        results['filename'] = filename
        results['ori_filename'] = filename

        img = mmcv.imread(src)
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        return results


def build_seg_demo_pipeline(*, img_scale=(2048, 448)):
    """Build the test-time pipeline that prepares a single demo image.

    Args:
        img_scale (tuple[int, int]): Scale passed to ``MultiScaleFlipAug``.
            ``Resize`` uses ``keep_ratio=True``. Defaults to ``(2048, 448)``.

    Returns:
        Compose: Callable taking ``dict(img=<path or array>)``. It loads the
        image with ``LoadImage``, then resizes, normalizes, converts the image
        to a tensor and collects ``'img'``.
    """
    # These are the ImageNet mean (0.485, 0.456, 0.406) and std
    # (0.229, 0.224, 0.225) multiplied by 255. ``to_rgb=True`` asks Normalize
    # to swap channels from mmcv.imread's BGR order to RGB.
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    test_pipeline = Compose([
        LoadImage(),
        dict(
            type='MultiScaleFlipAug',
            img_scale=img_scale,
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ]),
    ])
    return test_pipeline