|
@@ -0,0 +1,197 @@
|
|
|
+from torch.utils.data.sampler import Sampler
|
|
|
+from collections import defaultdict
|
|
|
+import copy
|
|
|
+import random
|
|
|
+import numpy as np
|
|
|
+import math
|
|
|
+import torch.distributed as dist
|
|
|
+_LOCAL_PROCESS_GROUP = None
|
|
|
+import torch
|
|
|
+import pickle
|
|
|
+
|
|
|
+def _get_global_gloo_group():
|
|
|
+ """
|
|
|
+ Return a process group based on gloo backend, containing all the ranks
|
|
|
+ The result is cached.
|
|
|
+ """
|
|
|
+ if dist.get_backend() == "nccl":
|
|
|
+ return dist.new_group(backend="gloo")
|
|
|
+ else:
|
|
|
+ return dist.group.WORLD
|
|
|
+
|
|
|
+def _serialize_to_tensor(data, group):
|
|
|
+ backend = dist.get_backend(group)
|
|
|
+ assert backend in ["gloo", "nccl"]
|
|
|
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
|
|
|
+
|
|
|
+ buffer = pickle.dumps(data)
|
|
|
+ if len(buffer) > 1024 ** 3:
|
|
|
+ print(
|
|
|
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
|
|
|
+ dist.get_rank(), len(buffer) / (1024 ** 3), device
|
|
|
+ )
|
|
|
+ )
|
|
|
+ storage = torch.ByteStorage.from_buffer(buffer)
|
|
|
+ tensor = torch.ByteTensor(storage).to(device=device)
|
|
|
+ return tensor
|
|
|
+
|
|
|
+def _pad_to_largest_tensor(tensor, group):
|
|
|
+ """
|
|
|
+ Returns:
|
|
|
+ list[int]: size of the tensor, on each rank
|
|
|
+ Tensor: padded tensor that has the max size
|
|
|
+ """
|
|
|
+ world_size = dist.get_world_size(group=group)
|
|
|
+ assert (
|
|
|
+ world_size >= 1
|
|
|
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
|
|
|
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
|
|
|
+ size_list = [
|
|
|
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
|
|
|
+ ]
|
|
|
+ dist.all_gather(size_list, local_size, group=group)
|
|
|
+ size_list = [int(size.item()) for size in size_list]
|
|
|
+
|
|
|
+ max_size = max(size_list)
|
|
|
+
|
|
|
+ # we pad the tensor because torch all_gather does not support
|
|
|
+ # gathering tensors of different shapes
|
|
|
+ if local_size != max_size:
|
|
|
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
|
|
|
+ tensor = torch.cat((tensor, padding), dim=0)
|
|
|
+ return size_list, tensor
|
|
|
+
|
|
|
+def all_gather(data, group=None):
|
|
|
+ """
|
|
|
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
|
|
|
+ Args:
|
|
|
+ data: any picklable object
|
|
|
+ group: a torch process group. By default, will use a group which
|
|
|
+ contains all ranks on gloo backend.
|
|
|
+ Returns:
|
|
|
+ list[data]: list of data gathered from each rank
|
|
|
+ """
|
|
|
+ if dist.get_world_size() == 1:
|
|
|
+ return [data]
|
|
|
+ if group is None:
|
|
|
+ group = _get_global_gloo_group()
|
|
|
+ if dist.get_world_size(group) == 1:
|
|
|
+ return [data]
|
|
|
+
|
|
|
+ tensor = _serialize_to_tensor(data, group)
|
|
|
+
|
|
|
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
|
|
+ max_size = max(size_list)
|
|
|
+
|
|
|
+ # receiving Tensor from all ranks
|
|
|
+ tensor_list = [
|
|
|
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
|
|
|
+ ]
|
|
|
+ dist.all_gather(tensor_list, tensor, group=group)
|
|
|
+
|
|
|
+ data_list = []
|
|
|
+ for size, tensor in zip(size_list, tensor_list):
|
|
|
+ buffer = tensor.cpu().numpy().tobytes()[:size]
|
|
|
+ data_list.append(pickle.loads(buffer))
|
|
|
+
|
|
|
+ return data_list
|
|
|
+
|
|
|
+def shared_random_seed():
|
|
|
+ """
|
|
|
+ Returns:
|
|
|
+ int: a random number that is the same across all workers.
|
|
|
+ If workers need a shared RNG, they can use this shared seed to
|
|
|
+ create one.
|
|
|
+ All workers must call this function, otherwise it will deadlock.
|
|
|
+ """
|
|
|
+ ints = np.random.randint(2 ** 31)
|
|
|
+ all_ints = all_gather(ints)
|
|
|
+ return all_ints[0]
|
|
|
+
|
|
|
+class RandomIdentitySampler_DDP(Sampler):
|
|
|
+ """
|
|
|
+ Randomly sample N identities, then for each identity,
|
|
|
+ randomly sample K instances, therefore batch size is N*K.
|
|
|
+ Args:
|
|
|
+ - data_source (list): list of (img_path, pid, camid).
|
|
|
+ - num_instances (int): number of instances per identity in a batch.
|
|
|
+ - batch_size (int): number of examples in a batch.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, data_source, batch_size, num_instances):
|
|
|
+ self.data_source = data_source
|
|
|
+ self.batch_size = batch_size
|
|
|
+ self.world_size = dist.get_world_size()
|
|
|
+ self.num_instances = num_instances
|
|
|
+ self.mini_batch_size = self.batch_size // self.world_size
|
|
|
+ self.num_pids_per_batch = self.mini_batch_size // self.num_instances
|
|
|
+ self.index_dic = defaultdict(list)
|
|
|
+
|
|
|
+ for index, (pid, _, _, _) in enumerate(self.data_source):
|
|
|
+ self.index_dic[pid].append(index)
|
|
|
+ self.pids = list(self.index_dic.keys())
|
|
|
+
|
|
|
+ # estimate number of examples in an epoch
|
|
|
+ self.length = 0
|
|
|
+ for pid in self.pids:
|
|
|
+ idxs = self.index_dic[pid]
|
|
|
+ num = len(idxs)
|
|
|
+ if num < self.num_instances:
|
|
|
+ num = self.num_instances
|
|
|
+ self.length += num - num % self.num_instances
|
|
|
+
|
|
|
+ self.rank = dist.get_rank()
|
|
|
+ #self.world_size = dist.get_world_size()
|
|
|
+ self.length //= self.world_size
|
|
|
+
|
|
|
+ def __iter__(self):
|
|
|
+ seed = shared_random_seed()
|
|
|
+ np.random.seed(seed)
|
|
|
+ self._seed = int(seed)
|
|
|
+ final_idxs = self.sample_list()
|
|
|
+ length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size))
|
|
|
+ #final_idxs = final_idxs[self.rank * length:(self.rank + 1) * length]
|
|
|
+ final_idxs = self.__fetch_current_node_idxs(final_idxs, length)
|
|
|
+ self.length = len(final_idxs)
|
|
|
+ return iter(final_idxs)
|
|
|
+
|
|
|
+
|
|
|
+ def __fetch_current_node_idxs(self, final_idxs, length):
|
|
|
+ total_num = len(final_idxs)
|
|
|
+ block_num = (length // self.mini_batch_size)
|
|
|
+ index_target = []
|
|
|
+ for i in range(0, block_num * self.world_size, self.world_size):
|
|
|
+ index = range(self.mini_batch_size * self.rank + self.mini_batch_size * i, min(self.mini_batch_size * self.rank + self.mini_batch_size * (i+1), total_num))
|
|
|
+ index_target.extend(index)
|
|
|
+ index_target_npy = np.array(index_target)
|
|
|
+ final_idxs = list(np.array(final_idxs)[index_target_npy])
|
|
|
+ return final_idxs
|
|
|
+
|
|
|
+
|
|
|
+ def sample_list(self):
|
|
|
+ #np.random.seed(self._seed)
|
|
|
+ avai_pids = copy.deepcopy(self.pids)
|
|
|
+ batch_idxs_dict = {}
|
|
|
+
|
|
|
+ batch_indices = []
|
|
|
+ while len(avai_pids) >= self.num_pids_per_batch:
|
|
|
+ selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist()
|
|
|
+ for pid in selected_pids:
|
|
|
+ if pid not in batch_idxs_dict:
|
|
|
+ idxs = copy.deepcopy(self.index_dic[pid])
|
|
|
+ if len(idxs) < self.num_instances:
|
|
|
+ idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
|
|
|
+ np.random.shuffle(idxs)
|
|
|
+ batch_idxs_dict[pid] = idxs
|
|
|
+
|
|
|
+ avai_idxs = batch_idxs_dict[pid]
|
|
|
+ for _ in range(self.num_instances):
|
|
|
+ batch_indices.append(avai_idxs.pop(0))
|
|
|
+
|
|
|
+ if len(avai_idxs) < self.num_instances: avai_pids.remove(pid)
|
|
|
+
|
|
|
+ return batch_indices
|
|
|
+
|
|
|
+ def __len__(self):
|
|
|
+ return self.length
|
|
|
+
|