icfgpedes.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import os.path as op
  2. from typing import List
  3. from utils.iotools import read_json
  4. from .bases import BaseDataset
  5. class ICFGPEDES(BaseDataset):
  6. """
  7. ICFG-PEDES
  8. Reference:
  9. Semantically Self-Aligned Network for Text-to-Image Part-aware Person Re-identification arXiv 2107
  10. URL: http://arxiv.org/abs/2107.12666
  11. Dataset statistics:
  12. # identities: 4102
  13. # images: 34674 (train) + 4855 (query) + 14993 (gallery)
  14. # cameras: 15
  15. """
  16. dataset_dir = 'ICFG-PEDES'
  17. def __init__(self, root='', verbose=True):
  18. super(ICFGPEDES, self).__init__()
  19. self.dataset_dir = op.join(root, self.dataset_dir)
  20. self.img_dir = op.join(self.dataset_dir, 'imgs/')
  21. self.anno_path = op.join(self.dataset_dir, 'ICFG-PEDES.json')
  22. self._check_before_run()
  23. self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path)
  24. self.train, self.train_id_container = self._process_anno(self.train_annos, training=True)
  25. self.test, self.test_id_container = self._process_anno(self.test_annos)
  26. self.val, self.val_id_container = self._process_anno(self.val_annos)
  27. if verbose:
  28. self.logger.info("=> ICFG-PEDES Images and Captions are loaded")
  29. self.show_dataset_info()
  30. def _split_anno(self, anno_path: str):
  31. train_annos, test_annos, val_annos = [], [], []
  32. annos = read_json(anno_path)
  33. for anno in annos:
  34. if anno['split'] == 'train':
  35. train_annos.append(anno)
  36. elif anno['split'] == 'test':
  37. test_annos.append(anno)
  38. else:
  39. val_annos.append(anno)
  40. return train_annos, test_annos, val_annos
  41. def _process_anno(self, annos: List[dict], training=False):
  42. pid_container = set()
  43. if training:
  44. dataset = []
  45. image_id = 0
  46. for anno in annos:
  47. pid = int(anno['id'])
  48. pid_container.add(pid)
  49. img_path = op.join(self.img_dir, anno['file_path'])
  50. captions = anno['captions'] # caption list
  51. for caption in captions:
  52. dataset.append((pid, image_id, img_path, caption))
  53. image_id += 1
  54. for idx, pid in enumerate(pid_container):
  55. # check pid begin from 0 and no break
  56. assert idx == pid, f"idx: {idx} and pid: {pid} are not match"
  57. return dataset, pid_container
  58. else:
  59. dataset = {}
  60. img_paths = []
  61. captions = []
  62. image_pids = []
  63. caption_pids = []
  64. for anno in annos:
  65. pid = int(anno['id'])
  66. pid_container.add(pid)
  67. img_path = op.join(self.img_dir, anno['file_path'])
  68. img_paths.append(img_path)
  69. image_pids.append(pid)
  70. caption_list = anno['captions'] # caption list
  71. for caption in caption_list:
  72. captions.append(caption)
  73. caption_pids.append(pid)
  74. dataset = {
  75. "image_pids": image_pids,
  76. "img_paths": img_paths,
  77. "caption_pids": caption_pids,
  78. "captions": captions
  79. }
  80. return dataset, pid_container
  81. def _check_before_run(self):
  82. """Check if all files are available before going deeper"""
  83. if not op.exists(self.dataset_dir):
  84. raise RuntimeError("'{}' is not available".format(self.dataset_dir))
  85. if not op.exists(self.img_dir):
  86. raise RuntimeError("'{}' is not available".format(self.img_dir))
  87. if not op.exists(self.anno_path):
  88. raise RuntimeError("'{}' is not available".format(self.anno_path))