# base_dataset.py
# -------------------------------------------------------------------------
# Written by Jilan Xu
# -------------------------------------------------------------------------
import os
# import linklink as link
import numpy as np
import torch
from torch.utils.data import Dataset

# memcached bindings are optional; if absent, the 'mc' read backend is
# simply unavailable (BaseDataset.__init__ rejects it anyway).
try:
    import mc
except ImportError:
    pass
# import ceph
# from petrel_client.client import Client
  15. class BaseDataset(Dataset):
  16. def __init__(self,
  17. root_dir,
  18. meta_file,
  19. transform=None,
  20. read_from='mc',
  21. evaluator=None):
  22. super(BaseDataset, self).__init__()
  23. self.root_dir = root_dir
  24. self.meta_file = meta_file
  25. self.transform = transform
  26. self.read_from = read_from
  27. self.evaluator = evaluator
  28. self.initialized = False
  29. print('READ from:', self.read_from)
  30. if self.read_from == 'petrel':
  31. self._init_petrel()
  32. elif self.read_from == 'dir':
  33. pass
  34. else:
  35. raise NotImplementedError
  36. def __len__(self):
  37. """
  38. Returns dataset length
  39. """
  40. raise NotImplementedError
  41. def __getitem__(self, idx):
  42. """
  43. Get a single image data: from dataset
  44. Arguments:
  45. - idx (:obj:`int`): index of image, 0 <= idx < len(self)
  46. """
  47. raise NotImplementedError
  48. def _init_petrel(self):
  49. from petrel_client.client import Client
  50. if not self.initialized:
  51. self.client = Client()
  52. self.initialized = True
  53. def read_file(self, meta_dict):
  54. value = self.client.get(meta_dict['filename'])
  55. filebytes = np.frombuffer(value, dtype=np.uint8)
  56. return filebytes
  57. def dump(self, writer, output):
  58. """
  59. Dump classification results
  60. Arguments:
  61. - writer: output stream
  62. - output (:obj:`dict`): different for imagenet and custom
  63. """
  64. raise NotImplementedError
  65. def merge(self, prefix):
  66. """
  67. Merge results into one file.
  68. Arguments:
  69. - prefix (:obj:`str`): dir/results.rank
  70. """
  71. world_size = link.get_world_size()
  72. merged_file = prefix.rsplit('.', 1)[0] + '.all'
  73. merged_fd = open(merged_file, 'w')
  74. for rank in range(world_size):
  75. res_file = prefix + str(rank)
  76. assert os.path.exists(res_file), f'No such file or directory: {res_file}'
  77. with open(res_file, 'r') as fin:
  78. for line_idx, line in enumerate(fin):
  79. merged_fd.write(line)
  80. merged_fd.close()
  81. return merged_file
  82. def inference(self, res_file):
  83. """
  84. Arguments:
  85. - res_file (:obj:`str`): filename of result
  86. """
  87. prefix = res_file.rstrip('0123456789')
  88. merged_res_file = self.merge(prefix)
  89. return merged_res_file
  90. def evaluate(self, res_file):
  91. """
  92. Arguments:
  93. - res_file (:obj:`str`): filename of result
  94. """
  95. prefix = res_file.rstrip('0123456789')
  96. merged_res_file = self.merge(prefix)
  97. metrics = self.evaluator.eval(merged_res_file) if self.evaluator else {}
  98. return metrics
  99. def tensor2numpy(self, x):
  100. if x is None:
  101. return x
  102. if torch.is_tensor(x):
  103. return x.cpu().numpy()
  104. if isinstance(x, list):
  105. x = [_.cpu().numpy() if torch.is_tensor(_) else _ for _ in x]
  106. return x