build.py

import logging
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from datasets.sampler import RandomIdentitySampler
from datasets.sampler_ddp import RandomIdentitySampler_DDP
from torch.utils.data.distributed import DistributedSampler
from utils.comm import get_world_size

from .bases import ImageDataset, TextDataset, ImageTextDataset, ImageTextMLMDataset
from .cuhkpedes import CUHKPEDES
from .icfgpedes import ICFGPEDES
from .rstpreid import RSTPReid

__factory = {'CUHK-PEDES': CUHKPEDES, 'ICFG-PEDES': ICFGPEDES, 'RSTPReid': RSTPReid}


def build_transforms(img_size=(384, 128), aug=False, is_train=True):
    height, width = img_size
    mean = [0.48145466, 0.4578275, 0.40821073]
    std = [0.26862954, 0.26130258, 0.27577711]

    if not is_train:
        transform = T.Compose([
            T.Resize((height, width)),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        return transform

    # transform for training
    if aug:
        transform = T.Compose([
            T.Resize((height, width)),
            T.RandomHorizontalFlip(0.5),
            T.Pad(10),
            T.RandomCrop((height, width)),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
            T.RandomErasing(scale=(0.02, 0.4), value=mean),
        ])
    else:
        transform = T.Compose([
            T.Resize((height, width)),
            T.RandomHorizontalFlip(0.5),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
    return transform


def collate(batch):
    keys = set([key for b in batch for key in b.keys()])
    # turn the list-of-dicts batch into a dict-of-lists structure
    dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys}

    batch_tensor_dict = {}
    for k, v in dict_batch.items():
        if isinstance(v[0], int):
            batch_tensor_dict.update({k: torch.tensor(v)})
        elif torch.is_tensor(v[0]):
            batch_tensor_dict.update({k: torch.stack(v)})
        else:
            raise TypeError(f"Unexpected data type: {type(v[0])} in a batch.")

    return batch_tensor_dict
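
# For example, a batch like [{'pids': 0, 'images': img0}, {'pids': 1, 'images': img1}]
# (illustrative field names, following the dataset items) collates to
# {'pids': tensor([0, 1]), 'images': <stacked image tensor>}.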


def build_dataloader(args, tranforms=None):
    logger = logging.getLogger("IRRA.dataset")

    num_workers = args.num_workers
    dataset = __factory[args.dataset_name](root=args.root_dir)
    num_classes = len(dataset.train_id_container)

    if args.training:
        train_transforms = build_transforms(img_size=args.img_size,
                                            aug=args.img_aug,
                                            is_train=True)
        val_transforms = build_transforms(img_size=args.img_size,
                                          is_train=False)

        if args.MLM:
            train_set = ImageTextMLMDataset(dataset.train,
                                            train_transforms,
                                            text_length=args.text_length)
        else:
            train_set = ImageTextDataset(dataset.train,
                                         train_transforms,
                                         text_length=args.text_length)

        if args.sampler == 'identity':
            if args.distributed:
                logger.info('using ddp random identity sampler')
                logger.info('DISTRIBUTED TRAIN START')
                mini_batch_size = args.batch_size // get_world_size()
                # TODO wait to fix bugs
                data_sampler = RandomIdentitySampler_DDP(
                    dataset.train, args.batch_size, args.num_instance)
                batch_sampler = torch.utils.data.sampler.BatchSampler(
                    data_sampler, mini_batch_size, True)
                # wrap the per-process batch sampler in a loader so that
                # train_loader is also defined on the distributed path
                train_loader = DataLoader(train_set,
                                          batch_sampler=batch_sampler,
                                          num_workers=num_workers,
                                          collate_fn=collate)
            else:
                logger.info(
                    f'using random identity sampler: batch_size: {args.batch_size}, id: {args.batch_size // args.num_instance}, instance: {args.num_instance}'
                )
                train_loader = DataLoader(train_set,
                                          batch_size=args.batch_size,
                                          sampler=RandomIdentitySampler(
                                              dataset.train, args.batch_size,
                                              args.num_instance),
                                          num_workers=num_workers,
                                          collate_fn=collate)
        elif args.sampler == 'random':
            # TODO add distributed condition
            logger.info('using random sampler')
            train_loader = DataLoader(train_set,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      collate_fn=collate)
        else:
            logger.error('unsupported sampler! expected identity or random but got {}'.format(args.sampler))

        # use the test split as the validation set unless args.val_dataset == 'val'
        ds = dataset.val if args.val_dataset == 'val' else dataset.test
        val_img_set = ImageDataset(ds['image_pids'], ds['img_paths'],
                                   val_transforms)
        val_txt_set = TextDataset(ds['caption_pids'],
                                  ds['captions'],
                                  text_length=args.text_length)

        val_img_loader = DataLoader(val_img_set,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=num_workers)
        val_txt_loader = DataLoader(val_txt_set,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=num_workers)

        return train_loader, val_img_loader, val_txt_loader, num_classes

    else:
        # build dataloader for testing
        if tranforms:
            test_transforms = tranforms
        else:
            test_transforms = build_transforms(img_size=args.img_size,
                                               is_train=False)

        ds = dataset.test
        test_img_set = ImageDataset(ds['image_pids'], ds['img_paths'],
                                    test_transforms)
        test_txt_set = TextDataset(ds['caption_pids'],
                                   ds['captions'],
                                   text_length=args.text_length)

        test_img_loader = DataLoader(test_img_set,
                                     batch_size=args.test_batch_size,
                                     shuffle=False,
                                     num_workers=num_workers)
        test_txt_loader = DataLoader(test_txt_set,
                                     batch_size=args.test_batch_size,
                                     shuffle=False,
                                     num_workers=num_workers)
        return test_img_loader, test_txt_loader, num_classes
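

# A minimal usage sketch (not part of the original module). It assumes an
# argparse-style namespace exposing the fields read by build_dataloader; the
# concrete values below are placeholders, and the CUHK-PEDES annotations are
# assumed to live under ./data.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        dataset_name='CUHK-PEDES', root_dir='./data', training=True,
        img_size=(384, 128), img_aug=True, MLM=False, text_length=77,
        sampler='random', distributed=False, batch_size=64, num_instance=4,
        num_workers=8, val_dataset='test', test_batch_size=512)
    train_loader, val_img_loader, val_txt_loader, num_classes = build_dataloader(args)
    print(f"{num_classes} identities, {len(train_loader)} train batches")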