# -------------------------------------------------------------------------
# Written by Jilan Xu
# -------------------------------------------------------------------------
import math

import numpy as np
import torch
from torch.utils.data.sampler import Sampler

import linklink as link

class DistributedSampler(Sampler):
    """Epoch-based sampler that partitions the dataset across ranks.

    If ``round_up`` is True, indices are padded so every rank receives the
    same number of samples; otherwise the last rank takes the remainder.
    """

    def __init__(self, dataset, world_size=None, rank=None, round_up=True):
        if world_size is None:
            world_size = link.get_world_size()
        if rank is None:
            rank = link.get_rank()
        self.dataset = dataset
        self.world_size = world_size
        self.rank = rank
        self.round_up = round_up
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) / self.world_size))
        if self.round_up:
            self.total_size = self.num_samples * self.world_size
            self.length = self.num_samples
        else:
            self.total_size = len(self.dataset)
            if self.rank < self.world_size - 1:
                self.length = self.num_samples
            else:
                # the last rank takes whatever is left over
                self.length = self.total_size - \
                    (self.world_size - 1) * self.num_samples

    def __iter__(self):
        # deterministic shuffle seeded only by the epoch, so every rank
        # generates the same global permutation
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = list(torch.randperm(len(self.dataset), generator=g))
        if self.round_up:
            # pad with the head of the permutation so the total divides evenly
            indices += indices[:(self.total_size - len(indices))]
            assert len(indices) == self.total_size
        # each rank takes its contiguous slice
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        if self.round_up or self.rank < self.world_size - 1:
            assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.length

    def set_epoch(self, epoch):
        self.epoch = epoch
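
# A minimal usage sketch (assumes a map-style `dataset` and an initialized
# linklink backend; the names below are illustrative): `set_epoch` must be
# called every epoch so the shuffle order changes while staying identical
# across ranks.
def _example_distributed_sampler(dataset, num_epochs=10, batch_size=32):
    from torch.utils.data import DataLoader
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # reseeds the per-epoch permutation
        for batch in loader:
            pass  # training step goes here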

class DistributedGivenIterationSampler(Sampler):
    """Iteration-based sampler.

    Pre-generates the full per-rank index list for a fixed number of
    iterations and, when resuming, slices off the part already consumed
    (``last_iter``).
    """

    def __init__(self, dataset, total_iter, batch_size, world_size=None,
                 rank=None, last_iter=0):
        if world_size is None:
            world_size = link.get_world_size()
        if rank is None:
            rank = link.get_rank()
        assert rank < world_size
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.last_iter = last_iter

        self.total_size = self.total_iter * self.batch_size

        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            # skip the indices that were already consumed before resuming
            return iter(self.indices[self.last_iter * self.batch_size:])
        else:
            raise RuntimeError(
                "this sampler is not designed to be called more than once")

    def gen_new_list(self):
        # the seed is fixed so that every rank builds the same global list
        np.random.seed(0)
        all_size = self.total_size * self.world_size
        indices = np.arange(len(self.dataset))
        indices = indices[:all_size]
        # tile the dataset until it covers all requested iterations
        num_repeat = (all_size - 1) // indices.shape[0] + 1
        indices = np.tile(indices, num_repeat)
        indices = indices[:all_size]

        np.random.shuffle(indices)
        # each rank takes a contiguous chunk of the shuffled global list
        beg = self.total_size * self.rank
        indices = indices[beg:beg + self.total_size]

        assert len(indices) == self.total_size

        return indices

    def __len__(self):
        # note: `last_iter` is deliberately ignored here, since __len__
        # should only be used for displaying; the correct remaining size
        # is handled by the dataloader
        return self.total_size
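
# A minimal resume sketch (hypothetical `checkpoint` dict and iteration
# budget): on restart, pass the iteration count saved in the checkpoint as
# `last_iter` and the sampler yields exactly the indices not yet consumed.
def _example_given_iteration_resume(dataset, checkpoint, batch_size=32):
    from torch.utils.data import DataLoader
    sampler = DistributedGivenIterationSampler(
        dataset, total_iter=100000, batch_size=batch_size,
        last_iter=checkpoint['iter'])
    # the resulting loader may be iterated once; it continues from
    # iteration checkpoint['iter']
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader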

class DistributedEpochSampler(Sampler):
    """Iteration-based sampler that shuffles epoch by epoch.

    Each simulated epoch pads the dataset with randomly chosen extra
    indices so its length divides ``world_size * batch_size`` evenly, then
    hands every rank an equal share; epochs are concatenated until
    ``total_iter`` iterations are covered.
    """

    def __init__(self, dataset, total_iter, batch_size, world_size=None,
                 rank=None, last_iter=0):
        if world_size is None:
            world_size = link.get_world_size()
        if rank is None:
            rank = link.get_rank()
        assert rank < world_size
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.last_iter = last_iter

        # number of samples this rank will see over all iterations
        self.all_size_single = self.total_iter * self.batch_size

        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            # skip the indices that were already consumed before resuming
            return iter(self.indices[self.last_iter * self.batch_size:])
        else:
            raise RuntimeError(
                "this sampler is not designed to be called more than once")

    def get_one_epoch_self_part(self):
        num = len(self.dataset)
        indices = np.arange(num)
        # pad with extra samples so the epoch divides evenly across ranks
        extra_indices = np.random.choice(
            num, self.extra_per_epoch, replace=False)
        indices = np.concatenate((indices, extra_indices))
        np.random.shuffle(indices)
        assert len(indices) % (self.world_size * self.batch_size) == 0
        num_single = len(indices) // self.world_size
        return indices[self.rank * num_single:(self.rank + 1) * num_single]

    def gen_new_list(self):
        # the seed is fixed so that every rank builds consistent epochs
        np.random.seed(0)
        self.all_num = self.total_iter * self.batch_size * self.world_size
        iter_per_epoch = (len(self.dataset) -
                          1) // (self.batch_size * self.world_size) + 1
        self.num_per_epoch = iter_per_epoch * self.batch_size * self.world_size
        self.extra_per_epoch = self.num_per_epoch - len(self.dataset)
        repeat = (self.all_num - 1) // self.num_per_epoch + 1
        indices = []
        for _ in range(repeat):
            indices.append(self.get_one_epoch_self_part())
        indices = np.concatenate(indices)
        indices = indices[:self.all_size_single]

        assert len(indices) == self.all_size_single

        return indices

    def __len__(self):
        return self.all_size_single
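
# A worked example of the padding arithmetic above (hypothetical numbers):
# with 1000 samples, batch_size 32 and world_size 8, one padded epoch must
# be a multiple of 32 * 8 = 256 samples.
def _example_epoch_padding():
    dataset_len, batch_size, world_size = 1000, 32, 8
    iter_per_epoch = (dataset_len - 1) // (batch_size * world_size) + 1  # 4
    num_per_epoch = iter_per_epoch * batch_size * world_size  # 1024
    extra_per_epoch = num_per_epoch - dataset_len  # 24 duplicated samples
    assert (num_per_epoch, extra_per_epoch) == (1024, 24)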

class RankedGivenIterationSampler(Sampler):
    """Single-process variant of the given-iteration sampler.

    Cycles through a shuffled copy of the dataset until ``total_iter``
    iterations are covered, resuming from ``last_iter`` if given.
    """

    def __init__(self, dataset, total_iter, batch_size, last_iter=0):
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.last_iter = last_iter

        self.total_size = self.total_iter * self.batch_size
        self.cur_size = self.last_iter * self.batch_size
        self.indices = np.arange(len(self.dataset))
        self.call = 0

    def indice_generator(self):
        # a single shuffle up front; subsequent passes reuse the same order
        np.random.shuffle(self.indices)
        while self.cur_size < self.total_size:
            remaining_size = self.total_size - self.cur_size
            indices = self.indices[:remaining_size]
            self.cur_size += len(indices)
            for item in indices:
                yield item

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            return self.indice_generator()
        else:
            raise RuntimeError(
                "this sampler is not designed to be called more than once")

    def __len__(self):
        # note: `last_iter` is deliberately ignored here, since __len__
        # should only be used for displaying; the correct remaining size
        # is handled by the dataloader
        return self.total_size
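
# A small sketch of the cycling behavior (hypothetical tiny dataset): with
# 4 samples and total_iter * batch_size == 10, the generator walks the
# shuffled order twice and then yields 2 more indices to reach 10.
def _example_ranked_iteration():
    tiny_dataset = list(range(4))  # stands in for a map-style dataset
    sampler = RankedGivenIterationSampler(
        tiny_dataset, total_iter=5, batch_size=2)
    indices = list(iter(sampler))  # a second iter() raises RuntimeError
    assert len(indices) == 10  # 5 iterations * batch_size 2
    return indices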

sampler_dict = {
    'distributed': DistributedSampler,
    'distributed_iteration': DistributedGivenIterationSampler,
    'distributed_epoch': DistributedEpochSampler,
    'ranked_iteration': RankedGivenIterationSampler,
}

def build_sampler(dataset, cfg_sampler, cfg_dataset):
    batch_size = cfg_dataset['batch_size']
    # determine the step type: iteration-based or epoch-based
    if not cfg_dataset.get('max_iter', False):
        world_size = link.get_world_size()
        iter_per_epoch = (len(dataset) - 1) // (batch_size * world_size) + 1
        if cfg_sampler['type'] == 'naive':
            total_iter = cfg_dataset['max_epoch'] * \
                ((len(dataset) - 1) // batch_size + 1)
        else:
            total_iter = cfg_dataset['max_epoch'] * iter_per_epoch
    else:
        total_iter = cfg_dataset['max_iter']

    # initialize sampler kwargs: epoch-based samplers only need the dataset,
    # iteration-based samplers also need the iteration bookkeeping
    if cfg_sampler['type'] in ['distributed', 'naive', 'random']:
        sampler_kwargs = {'dataset': dataset}
    else:
        sampler_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'total_iter': total_iter,
            'last_iter': cfg_dataset['last_iter'],
        }
    cfg_sampler['kwargs'].update(sampler_kwargs)
    cfg_dataset['max_iter'] = total_iter
    cfg_dataset.pop('dataset')
    return sampler_dict[cfg_sampler['type']](**cfg_sampler['kwargs'])
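
# A minimal configuration sketch (hypothetical values; `my_dataset` is any
# map-style dataset and `last_iter` comes from a checkpoint, or 0 for a
# fresh run):
def _example_build_sampler(my_dataset):
    cfg_sampler = {'type': 'distributed_iteration', 'kwargs': {}}
    cfg_dataset = {
        'batch_size': 32,
        'max_epoch': 30,
        'last_iter': 0,
        'dataset': my_dataset,  # popped by build_sampler after use
    }
    return build_sampler(my_dataset, cfg_sampler, cfg_dataset)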