# -------------------------------------------------------------------------
# Written by Jilan Xu
# -------------------------------------------------------------------------
import math

import numpy as np
import torch
from torch.utils.data.sampler import Sampler

import linklink as link


class DistributedSampler(Sampler):
    """Split a dataset across ranks and reshuffle it deterministically per epoch.

    Each rank receives ``ceil(len(dataset) / world_size)`` samples when
    ``round_up`` is True (the permutation is padded by wrapping around);
    otherwise the last rank receives the remainder.
    """

    def __init__(self, dataset, world_size=None, rank=None, round_up=True):
        if world_size is None:
            world_size = link.get_world_size()
        if rank is None:
            rank = link.get_rank()
        self.dataset = dataset
        self.world_size = world_size
        self.rank = rank
        self.round_up = round_up
        self.epoch = 0
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.world_size))
        if self.round_up:
            self.total_size = self.num_samples * self.world_size
            self.length = self.num_samples
        else:
            self.total_size = len(self.dataset)
            if self.rank < self.world_size - 1:
                self.length = self.num_samples
            else:
                self.length = self.total_size - \
                    (self.world_size - 1) * self.num_samples

    def __iter__(self):
        # Deterministic shuffle keyed on the current epoch, so all ranks draw
        # the same global permutation.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = list(torch.randperm(len(self.dataset), generator=g))
        if self.round_up:
            # Pad by wrapping around so the permutation splits evenly.
            indices += indices[:(self.total_size - len(indices))]
            assert len(indices) == self.total_size
        # Take this rank's contiguous chunk of the permutation.
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        if self.round_up or (not self.round_up and self.rank < self.world_size - 1):
            assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.length

    def set_epoch(self, epoch):
        self.epoch = epoch
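

# Illustrative usage sketch (not part of the original file): in a training
# script this sampler is typically paired with a DataLoader, and set_epoch()
# is called once per epoch so every rank reshuffles with the same epoch seed.
# The names `train_set` and `num_epochs` and the batch size are placeholders.
#
#   sampler = DistributedSampler(train_set, round_up=True)
#   loader = torch.utils.data.DataLoader(train_set, batch_size=64, sampler=sampler)
#   for epoch in range(num_epochs):
#       sampler.set_epoch(epoch)
#       for batch in loader:
#           ...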


class DistributedGivenIterationSampler(Sampler):
    """Sample exactly ``total_iter * batch_size`` indices per rank.

    The whole index schedule is generated up front with a fixed seed, so every
    rank sees a disjoint, reproducible slice of the same global ordering.
    """

    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=0):
        if world_size is None:
            world_size = link.get_world_size()
        if rank is None:
            rank = link.get_rank()
        assert rank < world_size
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.last_iter = last_iter

        self.total_size = self.total_iter * self.batch_size

        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            # Resume from the last finished iteration when restoring a run.
            return iter(self.indices[self.last_iter * self.batch_size:])
        else:
            raise RuntimeError(
                "this sampler is not designed to be called more than once!!")

    def gen_new_list(self):
        # Every process shuffles with the same seed, then keeps its own slice.
        np.random.seed(0)
        all_size = self.total_size * self.world_size
        indices = np.arange(len(self.dataset))
        indices = indices[:all_size]
        # Repeat the dataset as many times as needed to cover all iterations.
        num_repeat = (all_size - 1) // indices.shape[0] + 1
        indices = np.tile(indices, num_repeat)
        indices = indices[:all_size]

        np.random.shuffle(indices)
        beg = self.total_size * self.rank
        indices = indices[beg:beg + self.total_size]

        assert len(indices) == self.total_size

        return indices

    def __len__(self):
        # Note: the last iteration is deliberately not taken into account here,
        # since __len__ is only used for display; the correct remaining size is
        # handled by the dataloader.
        return self.total_size


class DistributedEpochSampler(Sampler):
    """Epoch-aware variant of DistributedGivenIterationSampler.

    Each epoch covers every sample once and is padded with extra random samples
    so its length divides evenly into ``world_size * batch_size``; epochs are
    concatenated until ``total_iter`` iterations are covered.
    """

    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=0):
        if world_size is None:
            world_size = link.get_world_size()
        if rank is None:
            rank = link.get_rank()
        assert rank < world_size
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.last_iter = last_iter

        self.all_size_single = self.total_iter * self.batch_size

        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            return iter(self.indices[self.last_iter * self.batch_size:])
        else:
            raise RuntimeError(
                "this sampler is not designed to be called more than once!!")

    def get_one_epoch_self_part(self):
        num = len(self.dataset)
        indices = np.arange(num)
        # Pad the epoch with extra samples (drawn without replacement) so it
        # splits evenly across ranks and batches.
        extra_indices = np.random.choice(
            num, self.extra_per_epoch, replace=False)
        indices = np.concatenate((indices, extra_indices))
        np.random.shuffle(indices)
        assert len(indices) % (self.world_size * self.batch_size) == 0
        num_single = len(indices) // self.world_size
        return indices[self.rank * num_single:(self.rank + 1) * num_single]

    def gen_new_list(self):
        # Every process shuffles with the same seed, then keeps its own slice.
        np.random.seed(0)
        self.all_num = self.total_iter * self.batch_size * self.world_size
        iter_per_epoch = (len(self.dataset) - 1) // (self.batch_size * self.world_size) + 1
        self.num_per_epoch = iter_per_epoch * self.batch_size * self.world_size
        self.extra_per_epoch = self.num_per_epoch - len(self.dataset)
        repeat = (self.all_num - 1) // self.num_per_epoch + 1
        indices = []
        for _ in range(repeat):
            indice = self.get_one_epoch_self_part()
            indices.append(indice)
        indices = np.concatenate(indices)
        indices = indices[:self.all_size_single]

        assert len(indices) == self.all_size_single

        return indices

    def __len__(self):
        return self.all_size_single


class RankedGivenIterationSampler(Sampler):
    """Single-process sampler yielding ``total_iter * batch_size`` indices.

    The index pool is shuffled once and then consumed lazily; when the pool is
    exhausted it is traversed again until the requested number of indices has
    been produced.
    """

    def __init__(self, dataset, total_iter, batch_size, last_iter=0):
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.last_iter = last_iter

        self.total_size = self.total_iter * self.batch_size
        self.cur_size = self.last_iter * self.batch_size
        self.indices = np.arange(len(self.dataset))
        self.call = 0

    def indice_generator(self):
        np.random.shuffle(self.indices)
        while self.cur_size < self.total_size:
            # np.random.shuffle(self.indices)
            remaining_size = self.total_size - self.cur_size
            indices = self.indices[:remaining_size]
            self.cur_size += len(indices)
            for item in indices:
                yield item

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            return self.indice_generator()
        else:
            raise RuntimeError(
                "this sampler is not designed to be called more than once!!")

    def __len__(self):
        # Note: the last iteration is deliberately not taken into account here,
        # since __len__ is only used for display; the correct remaining size is
        # handled by the dataloader.
        return self.total_size


sampler_dict = {
    'distributed': DistributedSampler,
    'distributed_iteration': DistributedGivenIterationSampler,
    'distributed_epoch': DistributedEpochSampler,
    'ranked_iteration': RankedGivenIterationSampler,
}


def build_sampler(dataset, cfg_sampler, cfg_dataset):
    batch_size = cfg_dataset['batch_size']
    # Check step type: iteration-based or epoch-based?
    if not getattr(cfg_dataset, 'max_iter', False):
        world_size = link.get_world_size()
        iter_per_epoch = (len(dataset) - 1) // (batch_size * world_size) + 1
        if cfg_sampler['type'] == "naive":
            total_iter = cfg_dataset['max_epoch'] * ((len(dataset) - 1) // batch_size + 1)  # 125200
        else:
            total_iter = cfg_dataset['max_epoch'] * iter_per_epoch
    else:
        total_iter = cfg_dataset['max_iter']

    # Initialize sampler kwargs.
    if cfg_sampler['type'] in ['distributed', "naive", "random"]:
        sampler_kwargs = {'dataset': dataset}
    else:
        sampler_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'total_iter': total_iter,
            'last_iter': cfg_dataset['last_iter'],
        }
    cfg_sampler['kwargs'].update(sampler_kwargs)
    cfg_dataset['max_iter'] = total_iter
    cfg_dataset.pop('dataset')
    return sampler_dict[cfg_sampler['type']](**cfg_sampler['kwargs'])
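

if __name__ == '__main__':
    # Minimal smoke test, not part of the original module: it exercises the
    # iteration-based samplers with an explicit world_size/rank so that no
    # distributed backend has to be initialized here (the module-level
    # linklink import must still resolve). The toy dataset size, batch size
    # and iteration counts are arbitrary illustrative values.
    toy_dataset = list(range(10))

    iter_sampler = DistributedGivenIterationSampler(
        toy_dataset, total_iter=4, batch_size=3, world_size=2, rank=0)
    print('distributed_iteration:', list(iter(iter_sampler)))

    ranked_sampler = RankedGivenIterationSampler(toy_dataset, total_iter=4, batch_size=3)
    print('ranked_iteration:', list(iter(ranked_sampler)))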