# cuhkpedes.py
  1. import os.path as op
  2. from typing import List
  3. from utils.iotools import read_json
  4. from .bases import BaseDataset
  5. class CUHKPEDES(BaseDataset):
  6. """
  7. CUHK-PEDES
  8. Reference:
  9. Person Search With Natural Language Description (CVPR 2017)
  10. URL: https://openaccess.thecvf.com/content_cvpr_2017/html/Li_Person_Search_With_CVPR_2017_paper.html
  11. Dataset statistics:
  12. ### identities: 13003
  13. ### images: 40206, (train) (test) (val)
  14. ### captions:
  15. ### 9 images have more than 2 captions
  16. ### 4 identity have only one image
  17. annotation format:
  18. [{'split', str,
  19. 'captions', list,
  20. 'file_path', str,
  21. 'processed_tokens', list,
  22. 'id', int}...]
  23. """
  24. dataset_dir = 'CUHK-PEDES'
  25. def __init__(self, root='', verbose=True):
  26. super(CUHKPEDES, self).__init__()
  27. self.dataset_dir = op.join(root, self.dataset_dir)
  28. self.img_dir = op.join(self.dataset_dir, 'imgs/')
  29. self.anno_path = op.join(self.dataset_dir, 'reid_raw.json')
  30. self._check_before_run()
  31. self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path)
  32. self.train, self.train_id_container = self._process_anno(self.train_annos, training=True)
  33. self.test, self.test_id_container = self._process_anno(self.test_annos)
  34. self.val, self.val_id_container = self._process_anno(self.val_annos)
  35. if verbose:
  36. self.logger.info("=> CUHK-PEDES Images and Captions are loaded")
  37. self.show_dataset_info()
  38. def _split_anno(self, anno_path: str):
  39. train_annos, test_annos, val_annos = [], [], []
  40. annos = read_json(anno_path)
  41. for anno in annos:
  42. if anno['split'] == 'train':
  43. train_annos.append(anno)
  44. elif anno['split'] == 'test':
  45. test_annos.append(anno)
  46. else:
  47. val_annos.append(anno)
  48. return train_annos, test_annos, val_annos
  49. def _process_anno(self, annos: List[dict], training=False):
  50. pid_container = set()
  51. if training:
  52. dataset = []
  53. image_id = 0
  54. for anno in annos:
  55. pid = int(anno['id']) - 1 # make pid begin from 0
  56. pid_container.add(pid)
  57. img_path = op.join(self.img_dir, anno['file_path'])
  58. captions = anno['captions'] # caption list
  59. for caption in captions:
  60. dataset.append((pid, image_id, img_path, caption))
  61. image_id += 1
  62. for idx, pid in enumerate(pid_container):
  63. # check pid begin from 0 and no break
  64. assert idx == pid, f"idx: {idx} and pid: {pid} are not match"
  65. return dataset, pid_container
  66. else:
  67. dataset = {}
  68. img_paths = []
  69. captions = []
  70. image_pids = []
  71. caption_pids = []
  72. for anno in annos:
  73. pid = int(anno['id'])
  74. pid_container.add(pid)
  75. img_path = op.join(self.img_dir, anno['file_path'])
  76. img_paths.append(img_path)
  77. image_pids.append(pid)
  78. caption_list = anno['captions'] # caption list
  79. for caption in caption_list:
  80. captions.append(caption)
  81. caption_pids.append(pid)
  82. dataset = {
  83. "image_pids": image_pids,
  84. "img_paths": img_paths,
  85. "caption_pids": caption_pids,
  86. "captions": captions
  87. }
  88. return dataset, pid_container
  89. def _check_before_run(self):
  90. """Check if all files are available before going deeper"""
  91. if not op.exists(self.dataset_dir):
  92. raise RuntimeError("'{}' is not available".format(self.dataset_dir))
  93. if not op.exists(self.img_dir):
  94. raise RuntimeError("'{}' is not available".format(self.img_dir))
  95. if not op.exists(self.anno_path):
  96. raise RuntimeError("'{}' is not available".format(self.anno_path))