123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- # -------------------------------------------------------------------------
- # Written by Jilan Xu
- # -------------------------------------------------------------------------
- import os
- # import linklink as link
- import numpy as np
- import torch
- from torch.utils.data import Dataset
- try:
- import mc
- except ImportError:
- pass
- # import ceph
- # from petrel_client.client import Client
class BaseDataset(Dataset):
    """
    Base class for datasets that read raw files through a storage client and
    write per-rank result files that are later merged and evaluated.

    Subclasses must implement ``__len__``, ``__getitem__`` and ``dump``.

    Arguments:
        - root_dir: stored on the instance; not used by this base class
        - meta_file: stored on the instance; not used by this base class
        - transform: stored on the instance; not used by this base class
        - read_from (:obj:`str`): storage backend, only ``'petrel'`` is implemented
        - evaluator: optional object whose ``eval(merged_file)`` is called by ``evaluate``
    """

    # Configuration file handed to the petrel client. Kept as a class attribute
    # so a subclass (or an instance) can point to another file without
    # overriding ``_init_petrel``.
    petrel_conf_path = '/mnt/petrelfs/xujilan/petreloss.conf'

    def __init__(self,
                 root_dir,
                 meta_file,
                 transform=None,
                 read_from='mc',
                 evaluator=None):
        super().__init__()

        self.root_dir = root_dir
        self.meta_file = meta_file
        self.transform = transform
        self.read_from = read_from
        self.evaluator = evaluator
        self.initialized = False

        if self.read_from == 'petrel':
            self._init_petrel()
        else:
            # NOTE: the default value 'mc' ends up here as well; it is kept
            # only so that the signature stays backward-compatible.
            raise NotImplementedError(
                f"read_from={self.read_from!r} is not supported; "
                "only 'petrel' is implemented")

    def __len__(self):
        """
        Returns dataset length
        """
        raise NotImplementedError

    def __getitem__(self, idx):
        """
        Get a single image data: from dataset
        Arguments:
            - idx (:obj:`int`): index of image, 0 <= idx < len(self)
        """
        raise NotImplementedError

    def _init_petrel(self):
        """
        Create ``self.client`` once; later calls are no-ops.
        Raises:
            - ImportError: if ``petrel_client`` cannot be imported
        """
        if self.initialized:
            return
        # Imported lazily: the module-level import is disabled in this file, so
        # a bare ``Client`` reference would fail with NameError.
        try:
            from petrel_client.client import Client
        except ImportError as err:
            raise ImportError(
                "read_from='petrel' requires the 'petrel_client' package") from err
        self.client = Client(self.petrel_conf_path)
        self.initialized = True

    def read_file(self, meta_dict):
        """
        Fetch ``meta_dict['filename']`` through the client.
        Arguments:
            - meta_dict (:obj:`dict`): must contain the key ``'filename'``
        Returns a 1-D ``uint8`` array built with ``np.frombuffer`` over the
        fetched value (assumes the client returns a bytes-like object -- TODO confirm).
        """
        value = self.client.get(meta_dict['filename'])
        return np.frombuffer(value, dtype=np.uint8)

    def dump(self, writer, output):
        """
        Dump classification results
        Arguments:
            - writer: output stream
            - output (:obj:`dict`): different for imagenet and custom
        """
        raise NotImplementedError

    @staticmethod
    def _get_world_size():
        """
        Number of processes in the default torch.distributed group, or 1 when
        distributed support is unavailable or no process group is initialised.
        """
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size()
        return 1

    def merge(self, prefix):
        """
        Merge results into one file.
        Arguments:
            - prefix (:obj:`str`): dir/results.rank
        Returns the path of the merged file: ``prefix`` with everything after
        its last ``'.'`` replaced by ``'all'``.
        Raises:
            - AssertionError: if the result file of some rank does not exist
        """
        world_size = self._get_world_size()
        merged_file = prefix.rsplit('.', 1)[0] + '.all'
        # Context manager so the merged file is closed even if a rank file is
        # missing or unreadable.
        with open(merged_file, 'w') as merged_fd:
            for rank in range(world_size):
                res_file = prefix + str(rank)
                assert os.path.exists(res_file), f'No such file or directory: {res_file}'
                with open(res_file, 'r') as fin:
                    merged_fd.writelines(fin)
        return merged_file

    def inference(self, res_file):
        """
        Merge the per-rank result files that ``res_file`` belongs to.
        Arguments:
            - res_file (:obj:`str`): filename of result
        Returns the path of the merged file.
        """
        # Drop the trailing rank digits to recover the prefix shared by all ranks.
        prefix = res_file.rstrip('0123456789')
        return self.merge(prefix)

    def evaluate(self, res_file):
        """
        Merge the per-rank result files and evaluate the merged file.
        Arguments:
            - res_file (:obj:`str`): filename of result
        Returns whatever ``self.evaluator.eval`` returns, or an empty dict when
        no evaluator is set. The merge is performed in both cases.
        """
        prefix = res_file.rstrip('0123456789')
        merged_res_file = self.merge(prefix)
        if not self.evaluator:
            return {}
        return self.evaluator.eval(merged_res_file)

    def tensor2numpy(self, x):
        """
        Convert a tensor, or the tensors inside a list, to numpy arrays.
        Arguments:
            - x: ``None``, a tensor, a list, or any other object
        ``None`` and non-tensor, non-list objects are returned unchanged. For a
        list a new list is returned in which only the tensor items are converted.
        """
        if x is None:
            return x
        if torch.is_tensor(x):
            # detach() first: numpy() refuses tensors that require grad.
            return x.detach().cpu().numpy()
        if isinstance(x, list):
            return [
                item.detach().cpu().numpy() if torch.is_tensor(item) else item
                for item in x
            ]
        return x
|