# rstpreid.py

import os.path as op
from typing import List

from utils.iotools import read_json
from .bases import BaseDataset
  5. class RSTPReid(BaseDataset):
  6. """
  7. RSTPReid
  8. Reference:
  9. DSSL: Deep Surroundings-person Separation Learning for Text-based Person Retrieval MM 21
  10. URL: http://arxiv.org/abs/2109.05534
  11. Dataset statistics:
  12. # identities: 4101
  13. """
  14. dataset_dir = 'RSTPReid'
  15. def __init__(self, root='', verbose=True):
  16. super(RSTPReid, self).__init__()
  17. self.dataset_dir = op.join(root, self.dataset_dir)
  18. self.img_dir = op.join(self.dataset_dir, 'imgs/')
  19. self.anno_path = op.join(self.dataset_dir, 'data_captions.json')
  20. self._check_before_run()
  21. self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path)
  22. self.train, self.train_id_container = self._process_anno(self.train_annos, training=True)
  23. self.test, self.test_id_container = self._process_anno(self.test_annos)
  24. self.val, self.val_id_container = self._process_anno(self.val_annos)
  25. if verbose:
  26. self.logger.info("=> RSTPReid Images and Captions are loaded")
  27. self.show_dataset_info()
  28. def _split_anno(self, anno_path: str):
  29. train_annos, test_annos, val_annos = [], [], []
  30. annos = read_json(anno_path)
  31. for anno in annos:
  32. if anno['split'] == 'train':
  33. train_annos.append(anno)
  34. elif anno['split'] == 'test':
  35. test_annos.append(anno)
  36. else:
  37. val_annos.append(anno)
  38. return train_annos, test_annos, val_annos
  39. def _process_anno(self, annos: List[dict], training=False):
  40. pid_container = set()
  41. if training:
  42. dataset = []
  43. image_id = 0
  44. for anno in annos:
  45. pid = int(anno['id'])
  46. pid_container.add(pid)
  47. img_path = op.join(self.img_dir, anno['img_path'])
  48. captions = anno['captions'] # caption list
  49. for caption in captions:
  50. dataset.append((pid, image_id, img_path, caption))
  51. image_id += 1
  52. for idx, pid in enumerate(pid_container):
  53. # check pid begin from 0 and no break
  54. assert idx == pid, f"idx: {idx} and pid: {pid} are not match"
  55. return dataset, pid_container
  56. else:
  57. dataset = {}
  58. img_paths = []
  59. captions = []
  60. image_pids = []
  61. caption_pids = []
  62. for anno in annos:
  63. pid = int(anno['id'])
  64. pid_container.add(pid)
  65. img_path = op.join(self.img_dir, anno['img_path'])
  66. img_paths.append(img_path)
  67. image_pids.append(pid)
  68. caption_list = anno['captions'] # caption list
  69. for caption in caption_list:
  70. captions.append(caption)
  71. caption_pids.append(pid)
  72. dataset = {
  73. "image_pids": image_pids,
  74. "img_paths": img_paths,
  75. "caption_pids": caption_pids,
  76. "captions": captions
  77. }
  78. return dataset, pid_container
  79. def _check_before_run(self):
  80. """Check if all files are available before going deeper"""
  81. if not op.exists(self.dataset_dir):
  82. raise RuntimeError("'{}' is not available".format(self.dataset_dir))
  83. if not op.exists(self.img_dir):
  84. raise RuntimeError("'{}' is not available".format(self.img_dir))
  85. if not op.exists(self.anno_path):
  86. raise RuntimeError("'{}' is not available".format(self.anno_path))