base_dataset.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # -------------------------------------------------------------------------
  2. # Written by Jilan Xu
  3. # -------------------------------------------------------------------------
  4. import os
  5. # import linklink as link
  6. import numpy as np
  7. import torch
  8. from torch.utils.data import Dataset
  9. try:
  10. import mc
  11. except ImportError:
  12. pass
  13. # import ceph
  14. # from petrel_client.client import Client
  15. class BaseDataset(Dataset):
  16. def __init__(self,
  17. root_dir,
  18. meta_file,
  19. transform=None,
  20. read_from='mc',
  21. evaluator=None):
  22. super(BaseDataset, self).__init__()
  23. self.root_dir = root_dir
  24. self.meta_file = meta_file
  25. self.transform = transform
  26. self.read_from = read_from
  27. self.evaluator = evaluator
  28. self.initialized = False
  29. if self.read_from == 'petrel':
  30. self._init_petrel()
  31. else:
  32. raise NotImplementedError
  33. def __len__(self):
  34. """
  35. Returns dataset length
  36. """
  37. raise NotImplementedError
  38. def __getitem__(self, idx):
  39. """
  40. Get a single image data: from dataset
  41. Arguments:
  42. - idx (:obj:`int`): index of image, 0 <= idx < len(self)
  43. """
  44. raise NotImplementedError
  45. def _init_petrel(self):
  46. if not self.initialized:
  47. self.client = Client('/mnt/petrelfs/xujilan/petreloss.conf')
  48. self.initialized = True
  49. def read_file(self, meta_dict):
  50. value = self.client.get(meta_dict['filename'])
  51. filebytes = np.frombuffer(value, dtype=np.uint8)
  52. return filebytes
  53. def dump(self, writer, output):
  54. """
  55. Dump classification results
  56. Arguments:
  57. - writer: output stream
  58. - output (:obj:`dict`): different for imagenet and custom
  59. """
  60. raise NotImplementedError
  61. def merge(self, prefix):
  62. """
  63. Merge results into one file.
  64. Arguments:
  65. - prefix (:obj:`str`): dir/results.rank
  66. """
  67. world_size = link.get_world_size()
  68. merged_file = prefix.rsplit('.', 1)[0] + '.all'
  69. merged_fd = open(merged_file, 'w')
  70. for rank in range(world_size):
  71. res_file = prefix + str(rank)
  72. assert os.path.exists(res_file), f'No such file or directory: {res_file}'
  73. with open(res_file, 'r') as fin:
  74. for line_idx, line in enumerate(fin):
  75. merged_fd.write(line)
  76. merged_fd.close()
  77. return merged_file
  78. def inference(self, res_file):
  79. """
  80. Arguments:
  81. - res_file (:obj:`str`): filename of result
  82. """
  83. prefix = res_file.rstrip('0123456789')
  84. merged_res_file = self.merge(prefix)
  85. return merged_res_file
  86. def evaluate(self, res_file):
  87. """
  88. Arguments:
  89. - res_file (:obj:`str`): filename of result
  90. """
  91. prefix = res_file.rstrip('0123456789')
  92. merged_res_file = self.merge(prefix)
  93. metrics = self.evaluator.eval(merged_res_file) if self.evaluator else {}
  94. return metrics
  95. def tensor2numpy(self, x):
  96. if x is None:
  97. return x
  98. if torch.is_tensor(x):
  99. return x.cpu().numpy()
  100. if isinstance(x, list):
  101. x = [_.cpu().numpy() if torch.is_tensor(_) else _ for _ in x]
  102. return x