from collections import defaultdict
import copy
import functools
import math
import pickle
import random

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler

_LOCAL_PROCESS_GROUP = None


@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        # NCCL cannot gather CPU byte tensors, so build a dedicated gloo group.
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD


def _serialize_to_tensor(data, group):
    """Pickle `data` and return it as a 1-D uint8 tensor on the group's device."""
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        print(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                dist.get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor


def _pad_to_largest_tensor(tensor, group):
    """
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if dist.get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving tensors from all ranks
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
            If workers need a shared RNG, they can use this
            shared seed to create one.

    All workers must call this function, otherwise it will deadlock.
    """
    ints = np.random.randint(2 ** 31)
    all_ints = all_gather(ints)
    return all_ints[0]


class RandomIdentitySampler_DDP(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Args:
    - data_source (list): dataset items; each item is a tuple whose first
      element is the person id (pid).
    - num_instances (int): number of instances per identity in a batch.
    - batch_size (int): total number of examples in a batch across all ranks.
""" def __init__(self, data_source, batch_size, num_instances): self.data_source = data_source self.batch_size = batch_size self.world_size = dist.get_world_size() self.num_instances = num_instances self.mini_batch_size = self.batch_size // self.world_size self.num_pids_per_batch = self.mini_batch_size // self.num_instances self.index_dic = defaultdict(list) for index, (pid, _, _, _) in enumerate(self.data_source): self.index_dic[pid].append(index) self.pids = list(self.index_dic.keys()) # estimate number of examples in an epoch self.length = 0 for pid in self.pids: idxs = self.index_dic[pid] num = len(idxs) if num < self.num_instances: num = self.num_instances self.length += num - num % self.num_instances self.rank = dist.get_rank() #self.world_size = dist.get_world_size() self.length //= self.world_size def __iter__(self): seed = shared_random_seed() np.random.seed(seed) self._seed = int(seed) final_idxs = self.sample_list() length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size)) #final_idxs = final_idxs[self.rank * length:(self.rank + 1) * length] final_idxs = self.__fetch_current_node_idxs(final_idxs, length) self.length = len(final_idxs) return iter(final_idxs) def __fetch_current_node_idxs(self, final_idxs, length): total_num = len(final_idxs) block_num = (length // self.mini_batch_size) index_target = [] for i in range(0, block_num * self.world_size, self.world_size): index = range(self.mini_batch_size * self.rank + self.mini_batch_size * i, min(self.mini_batch_size * self.rank + self.mini_batch_size * (i+1), total_num)) index_target.extend(index) index_target_npy = np.array(index_target) final_idxs = list(np.array(final_idxs)[index_target_npy]) return final_idxs def sample_list(self): #np.random.seed(self._seed) avai_pids = copy.deepcopy(self.pids) batch_idxs_dict = {} batch_indices = [] while len(avai_pids) >= self.num_pids_per_batch: selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist() for pid in selected_pids: if pid not in batch_idxs_dict: idxs = copy.deepcopy(self.index_dic[pid]) if len(idxs) < self.num_instances: idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist() np.random.shuffle(idxs) batch_idxs_dict[pid] = idxs avai_idxs = batch_idxs_dict[pid] for _ in range(self.num_instances): batch_indices.append(avai_idxs.pop(0)) if len(avai_idxs) < self.num_instances: avai_pids.remove(pid) return batch_indices def __len__(self): return self.length