# build.py — validation dataset / dataloader construction for CUHK-PEDES.
  1. import logging
  2. import torch
  3. import torchvision.transforms as T
  4. from torch.utils.data import DataLoader
  5. # from datasets.sampler import RandomIdentitySampler
  6. # from datasets.sampler_ddp import RandomIdentitySampler_DDP
  7. from utils.comm import get_world_size
  8. from .cuhkpedes import CUHKPEDES
  9. from .bases import ImageDataset, TextDataset, ImageTextDataset, ImageTextMLMDataset
  10. # __factory = {'CUHK-PEDES': CUHKPEDES, 'ICFG-PEDES': ICFGPEDES, 'RSTPReid': RSTPReid}
  11. __factory = {'CUHK-PEDES': CUHKPEDES}
  12. def build_transforms(img_size=(384, 128), aug=False, is_train=True):
  13. height, width = img_size
  14. mean = [0.48145466, 0.4578275, 0.40821073]
  15. std = [0.26862954, 0.26130258, 0.27577711]
  16. if not is_train:
  17. transform = T.Compose([
  18. T.Resize((height, width)),
  19. T.ToTensor(),
  20. T.Normalize(mean=mean, std=std),
  21. ])
  22. return transform
  23. # transform for training
  24. if aug:
  25. transform = T.Compose([
  26. T.Resize((height, width)),
  27. T.RandomHorizontalFlip(0.5),
  28. T.Pad(10),
  29. T.RandomCrop((height, width)),
  30. T.ToTensor(),
  31. T.Normalize(mean=mean, std=std),
  32. T.RandomErasing(scale=(0.02, 0.4), value=mean),
  33. ])
  34. else:
  35. transform = T.Compose([
  36. T.Resize((height, width)),
  37. T.RandomHorizontalFlip(0.5),
  38. T.ToTensor(),
  39. T.Normalize(mean=mean, std=std),
  40. ])
  41. return transform
  42. def collate(batch):
  43. keys = set([key for b in batch for key in b.keys()])
  44. # turn list of dicts data structure to dict of lists data structure
  45. dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys}
  46. batch_tensor_dict = {}
  47. for k, v in dict_batch.items():
  48. if isinstance(v[0], int):
  49. batch_tensor_dict.update({k: torch.tensor(v)})
  50. elif torch.is_tensor(v[0]):
  51. batch_tensor_dict.update({k: torch.stack(v)})
  52. else:
  53. raise TypeError(f"Unexpect data type: {type(v[0])} in a batch.")
  54. return batch_tensor_dict
  55. def build_dataloader(args, tranforms=None):
  56. logger = logging.getLogger("IRRA.dataset")
  57. num_workers = args.data.num_workers
  58. dataset = __factory[args.data.dataset.meta.cuhkpedes_val.name](root=args.data.dataset.meta.cuhkpedes_val.raw_path)
  59. num_classes = len(dataset.train_id_container)
  60. val_transforms = build_transforms(img_size=(args.data.img_aug.img_size * 3, args.data.img_aug.img_size),
  61. is_train=False)
  62. # use test set as validate set
  63. ds = dataset.val
  64. val_img_set = ImageDataset(ds['image_pids'], ds['img_paths'],
  65. val_transforms)
  66. val_txt_set = TextDataset(ds['caption_pids'],
  67. ds['captions'],
  68. text_length=args.data.text_aug.max_seq_len)
  69. val_img_loader = DataLoader(val_img_set,
  70. batch_size=args.data.batch_size,
  71. shuffle=False,
  72. num_workers=num_workers)
  73. val_txt_loader = DataLoader(val_txt_set,
  74. batch_size=args.data.batch_size,
  75. shuffle=False,
  76. num_workers=num_workers)
  77. return val_img_loader, val_txt_loader, num_classes