sampler_ddp.py

import copy
import functools
import math
import pickle
import random
from collections import defaultdict

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler

_LOCAL_PROCESS_GROUP = None

@functools.lru_cache()  # cache so the gloo group is created only once per process
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD

def _serialize_to_tensor(data, group):
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        print(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                dist.get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor

def _pad_to_largest_tensor(tensor, group):
    """
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor

def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: list of data gathered from each rank
    """
    if dist.get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving Tensor from all ranks
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list
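
# Usage sketch for all_gather (illustrative only; `local_stats` and `my_dataset`
# are placeholder names, and torch.distributed is assumed to be initialized on
# every rank before the call):
#
#   local_stats = {"rank": dist.get_rank(), "num_samples": len(my_dataset)}
#   stats_per_rank = all_gather(local_stats)
#   # stats_per_rank is the same list on every rank, one entry per rank.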

def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
            If workers need a shared RNG, they can use this shared seed to
            create one.
    All workers must call this function, otherwise it will deadlock.
    """
    ints = np.random.randint(2 ** 31)
    all_ints = all_gather(ints)
    return all_ints[0]

class RandomIdentitySampler_DDP(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.
    Args:
        - data_source (list): list of 4-element dataset tuples whose first element is the person id (pid).
        - num_instances (int): number of instances per identity in a batch.
        - batch_size (int): total number of examples in a batch, summed over all ranks.
    """

    def __init__(self, data_source, batch_size, num_instances):
        self.data_source = data_source
        self.batch_size = batch_size
        self.world_size = dist.get_world_size()
        self.num_instances = num_instances
        self.mini_batch_size = self.batch_size // self.world_size
        self.num_pids_per_batch = self.mini_batch_size // self.num_instances
        self.index_dic = defaultdict(list)
        for index, (pid, _, _, _) in enumerate(self.data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())

        # estimate number of examples in an epoch
        self.length = 0
        for pid in self.pids:
            idxs = self.index_dic[pid]
            num = len(idxs)
            if num < self.num_instances:
                num = self.num_instances
            self.length += num - num % self.num_instances

        self.rank = dist.get_rank()
        self.length //= self.world_size
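
    # Every rank derives the same epoch-level permutation from a shared seed
    # (see shared_random_seed above), then keeps only its own interleaved
    # slice of the index list, so ranks do not overlap within an epoch.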

    def __iter__(self):
        seed = shared_random_seed()
        np.random.seed(seed)
        self._seed = int(seed)
        final_idxs = self.sample_list()
        length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size))
        final_idxs = self.__fetch_current_node_idxs(final_idxs, length)
        self.length = len(final_idxs)
        return iter(final_idxs)
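
    # __fetch_current_node_idxs keeps every world_size-th mini-batch-sized block
    # of final_idxs, offset by this rank. Illustrative example: with world_size=2,
    # mini_batch_size=4 and 16 sampled indices, rank 0 keeps positions [0:4] and
    # [8:12] while rank 1 keeps [4:8] and [12:16], so each rank still receives
    # whole P*K blocks.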

    def __fetch_current_node_idxs(self, final_idxs, length):
        total_num = len(final_idxs)
        block_num = length // self.mini_batch_size
        index_target = []
        for i in range(0, block_num * self.world_size, self.world_size):
            start = self.mini_batch_size * (self.rank + i)
            end = min(self.mini_batch_size * (self.rank + i + 1), total_num)
            index_target.extend(range(start, end))
        index_target_npy = np.array(index_target)
        final_idxs = list(np.array(final_idxs)[index_target_npy])
        return final_idxs
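
    # sample_list implements PK sampling: repeatedly draw num_pids_per_batch
    # identities and emit num_instances indices for each, so every contiguous
    # mini_batch_size slice of the returned list holds exactly
    # num_pids_per_batch identities with num_instances samples apiece.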

    def sample_list(self):
        avai_pids = copy.deepcopy(self.pids)
        batch_idxs_dict = {}

        batch_indices = []
        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist()
            for pid in selected_pids:
                if pid not in batch_idxs_dict:
                    idxs = copy.deepcopy(self.index_dic[pid])
                    if len(idxs) < self.num_instances:
                        idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
                    np.random.shuffle(idxs)
                    batch_idxs_dict[pid] = idxs
                avai_idxs = batch_idxs_dict[pid]
                for _ in range(self.num_instances):
                    batch_indices.append(avai_idxs.pop(0))
                if len(avai_idxs) < self.num_instances:
                    avai_pids.remove(pid)
        return batch_indices

    def __len__(self):
        return self.length
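
# Usage sketch (hypothetical names such as `train_set`; not part of this file):
# under DDP each process builds the sampler with the *global* batch size, wraps
# it in a BatchSampler of size batch_size // world_size, and passes that to the
# DataLoader as batch_sampler.
#
#   sampler = RandomIdentitySampler_DDP(train_set, batch_size=64, num_instances=4)
#   batch_sampler = torch.utils.data.sampler.BatchSampler(
#       sampler, batch_size=64 // dist.get_world_size(), drop_last=True
#   )
#   loader = torch.utils.data.DataLoader(
#       train_set, batch_sampler=batch_sampler, num_workers=4
#   )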