@@ -0,0 +1,197 @@
+from torch.utils.data.sampler import Sampler
+from collections import defaultdict
+import copy
+import random
+import numpy as np
+import math
+import torch.distributed as dist
+import torch
+import pickle
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.

    The result is cached on the function object: ``dist.new_group`` is a
    collective call that creates a brand-new group every time it runs, so it
    must be issued at most once per process rather than once per gather.

    Returns:
        ``dist.group.WORLD`` when the default backend is not NCCL, otherwise
        a gloo group that is created on first use and reused afterwards.
    """
    if dist.get_backend() != "nccl":
        return dist.group.WORLD

    gloo_group = getattr(_get_global_gloo_group, "_gloo_group", None)
    if gloo_group is None:
        gloo_group = dist.new_group(backend="gloo")
        # NOTE: the cache is never invalidated; it assumes the default
        # process group lives for the whole lifetime of the process.
        _get_global_gloo_group._gloo_group = gloo_group
    return gloo_group
def _serialize_to_tensor(data, group):
    """
    Pickle ``data`` into a 1-D ``uint8`` tensor that can be all-gathered.

    Args:
        data: any picklable object.
        group: process group whose backend decides the tensor's device
            (CPU for gloo, CUDA for nccl).

    Returns:
        Tensor: 1-D ``torch.uint8`` tensor holding the pickled bytes.
    """
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        print(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                dist.get_rank(), len(buffer) / (1024 ** 3), device
            )
        )

    # np.frombuffer yields a read-only view over the immutable bytes object;
    # copy it so the tensor owns writable memory. This replaces the deprecated
    # typed-storage constructor torch.ByteStorage.from_buffer.
    raw = np.frombuffer(buffer, dtype=np.uint8).copy()
    return torch.from_numpy(raw).to(device=device)
def _pad_to_largest_tensor(tensor, group):
    """
    Zero-pad a 1-D byte tensor to the largest length found on any rank.

    Args:
        tensor: 1-D ``uint8`` tensor local to this rank.
        group: process group over which the sizes are exchanged.

    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"

    # Keep the local length as a plain int: comparing / subtracting with a
    # 1-element tensor works only through implicit conversion and forces a
    # device sync when the tensor lives on CUDA.
    local_numel = tensor.numel()
    local_size = torch.tensor([local_numel], dtype=torch.int64, device=tensor.device)
    gathered_sizes = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
    ]
    dist.all_gather(gathered_sizes, local_size, group=group)
    size_list = [int(size.item()) for size in gathered_sizes]
    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_numel != max_size:
        padding = torch.zeros((max_size - local_numel,), dtype=torch.uint8, device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank. When no process
        group is initialized, or the world/group has a single rank, this is
        simply ``[data]``.
    """
    # A plain single-process run never initializes torch.distributed; treat it
    # as a world of size one instead of raising from get_world_size().
    if not (dist.is_available() and dist.is_initialized()):
        return [data]
    if dist.get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    payload = _serialize_to_tensor(data, group)
    size_list, payload = _pad_to_largest_tensor(payload, group)
    max_size = max(size_list)

    # receiving Tensor from all ranks
    received = [
        torch.empty((max_size,), dtype=torch.uint8, device=payload.device) for _ in size_list
    ]
    dist.all_gather(received, payload, group=group)

    # NOTE(security): pickle.loads executes whatever the peer ranks sent; this
    # is only appropriate when every rank in the job is trusted.
    return [
        pickle.loads(chunk.cpu().numpy().tobytes()[:size])
        for size, chunk in zip(size_list, received)
    ]
def shared_random_seed():
    """
    Agree on one random seed across all workers.

    Each worker draws its own integer below 2**31 from numpy's global RNG,
    the draws are all-gathered, and every worker keeps the first entry of the
    gathered list.

    Returns:
        int: a random number that is the same across all workers.
        If workers need a shared RNG, they can use this shared seed to
        create one.

    All workers must call this function, otherwise it will deadlock.
    """
    local_draw = np.random.randint(2 ** 31)
    gathered_draws = all_gather(local_draw)
    return gathered_draws[0]
class RandomIdentitySampler_DDP(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    The full index list is built identically on every rank (the numpy RNG is
    re-seeded with a seed shared across ranks) and each rank then keeps its
    own interleaved mini-batches of that list.

    Args:
    - data_source (list): items are unpacked as 4-tuples and the FIRST field
      is used as the identity (pid).
      NOTE(review): the previous docstring described (img_path, pid, camid);
      confirm the field order against the dataset that feeds this sampler.
    - batch_size (int): number of examples in a batch, summed over all ranks.
    - num_instances (int): number of instances per identity in a batch.

    Raises:
        ValueError: if the per-rank batch cannot hold a single identity.
    """

    def __init__(self, data_source, batch_size, num_instances):
        self.data_source = data_source
        self.batch_size = batch_size
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.num_instances = num_instances
        self.mini_batch_size = self.batch_size // self.world_size
        self.num_pids_per_batch = self.mini_batch_size // self.num_instances
        # Fail fast: with zero identities per batch, sample_list could never
        # make progress.
        if self.num_pids_per_batch < 1:
            raise ValueError(
                "batch_size // world_size must be at least num_instances "
                "(got batch_size={}, world_size={}, num_instances={})".format(
                    batch_size, self.world_size, num_instances
                )
            )

        # pid -> positions in data_source carrying that pid
        self.index_dic = defaultdict(list)
        for index, (pid, _, _, _) in enumerate(self.data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())

        # estimate number of examples in an epoch: every identity contributes
        # a multiple of num_instances (identities with too few samples are
        # oversampled up to num_instances)
        total = 0
        for idxs in self.index_dic.values():
            num = max(len(idxs), self.num_instances)
            total += num - num % self.num_instances
        self.length = total // self.world_size

    def __iter__(self):
        # All ranks seed numpy's global RNG with the same value so that
        # sample_list produces the same order everywhere.
        seed = shared_random_seed()
        np.random.seed(seed)
        self._seed = int(seed)
        final_idxs = self.sample_list()
        length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size))
        final_idxs = self.__fetch_current_node_idxs(final_idxs, length)
        # replace the estimate from __init__ with the exact count
        self.length = len(final_idxs)
        return iter(final_idxs)

    def __fetch_current_node_idxs(self, final_idxs, length):
        """Return this rank's interleaved mini-batches out of ``final_idxs``.

        Mini-batch number ``rank + k * world_size`` (k = 0, 1, ...) belongs to
        this rank; ``length`` is the per-rank share used to derive how many
        mini-batches to take.
        """
        total_num = len(final_idxs)
        block_num = length // self.mini_batch_size
        index_target = []
        for i in range(0, block_num * self.world_size, self.world_size):
            start = self.mini_batch_size * (self.rank + i)
            stop = min(start + self.mini_batch_size, total_num)
            index_target.extend(range(start, stop))
        # Explicit integer dtype: an empty selection would otherwise become a
        # float array, which numpy refuses to use as an index.
        index_target_npy = np.array(index_target, dtype=np.int64)
        return list(np.array(final_idxs)[index_target_npy])

    def sample_list(self):
        """Build the rank-agnostic index order using numpy's global RNG.

        Returns:
            list: indices grouped in runs of ``num_instances`` per identity,
            ``num_pids_per_batch`` identities per mini-batch.
        """
        if self.num_pids_per_batch < 1:
            # the loop below could never terminate
            raise ValueError("num_pids_per_batch must be at least 1")

        avai_pids = list(self.pids)
        batch_idxs_dict = {}
        batch_indices = []

        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = np.random.choice(
                avai_pids, self.num_pids_per_batch, replace=False
            ).tolist()
            for pid in selected_pids:
                if pid not in batch_idxs_dict:
                    # first time this identity is drawn: prepare its shuffled
                    # pool, oversampling when it has too few samples
                    idxs = list(self.index_dic[pid])
                    if len(idxs) < self.num_instances:
                        idxs = np.random.choice(
                            idxs, size=self.num_instances, replace=True
                        ).tolist()
                    np.random.shuffle(idxs)
                    batch_idxs_dict[pid] = idxs

                avai_idxs = batch_idxs_dict[pid]
                # take the next num_instances entries from the front
                batch_indices.extend(avai_idxs[:self.num_instances])
                del avai_idxs[:self.num_instances]

                if len(avai_idxs) < self.num_instances:
                    avai_pids.remove(pid)
        return batch_indices

    def __len__(self):
        """Per-rank sample count: an estimate until the first ``__iter__``."""
        return self.length