# bases.py
  1. from typing import List
  2. from torch.utils.data import Dataset
  3. import os.path as osp
  4. import logging
  5. import torch
  6. from utils.iotools import read_image
  7. from utils.simple_tokenizer import SimpleTokenizer
  8. from prettytable import PrettyTable
  9. import random
  10. import regex as re
  11. import copy
  12. class BaseDataset(object):
  13. """
  14. Base class of text to image reid dataset
  15. """
  16. logger = logging.getLogger("IRRA.dataset")
  17. def show_dataset_info(self):
  18. num_train_pids, num_train_imgs, num_train_captions = len(
  19. self.train_id_container), len(self.train_annos), len(self.train)
  20. num_test_pids, num_test_imgs, num_test_captions = len(
  21. self.test_id_container), len(self.test_annos), len(
  22. self.test['captions'])
  23. num_val_pids, num_val_imgs, num_val_captions = len(
  24. self.val_id_container), len(self.val_annos), len(
  25. self.val['captions'])
  26. # TODO use prettytable print comand line table
  27. self.logger.info(f"{self.__class__.__name__} Dataset statistics:")
  28. table = PrettyTable(['subset', 'ids', 'images', 'captions'])
  29. table.add_row(
  30. ['train', num_train_pids, num_train_imgs, num_train_captions])
  31. table.add_row(
  32. ['test', num_test_pids, num_test_imgs, num_test_captions])
  33. table.add_row(['val', num_val_pids, num_val_imgs, num_val_captions])
  34. self.logger.info('\n' + str(table))
  35. def tokenize(caption: str, tokenizer, text_length=77, truncate=True) -> torch.LongTensor:
  36. sot_token = tokenizer.encoder["<|startoftext|>"]
  37. eot_token = tokenizer.encoder["<|endoftext|>"]
  38. tokens = [sot_token] + tokenizer.encode(caption) + [eot_token]
  39. result = torch.zeros(text_length, dtype=torch.long)
  40. if len(tokens) > text_length:
  41. if truncate:
  42. tokens = tokens[:text_length]
  43. tokens[-1] = eot_token
  44. else:
  45. raise RuntimeError(
  46. f"Input {caption} is too long for context length {text_length}"
  47. )
  48. result[:len(tokens)] = torch.tensor(tokens)
  49. return result
  50. class ImageTextDataset(Dataset):
  51. def __init__(self,
  52. dataset,
  53. transform=None,
  54. text_length: int = 77,
  55. truncate: bool = True):
  56. self.dataset = dataset
  57. self.transform = transform
  58. self.text_length = text_length
  59. self.truncate = truncate
  60. self.tokenizer = SimpleTokenizer()
  61. def __len__(self):
  62. return len(self.dataset)
  63. def __getitem__(self, index):
  64. pid, image_id, img_path, caption = self.dataset[index]
  65. img = read_image(img_path)
  66. if self.transform is not None:
  67. img = self.transform(img)
  68. tokens = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate)
  69. ret = {
  70. 'pids': pid,
  71. 'image_ids': image_id,
  72. 'images': img,
  73. 'caption_ids': tokens,
  74. }
  75. return ret
  76. class ImageDataset(Dataset):
  77. def __init__(self, image_pids, img_paths, transform=None):
  78. self.image_pids = image_pids
  79. self.img_paths = img_paths
  80. self.transform = transform
  81. def __len__(self):
  82. return len(self.image_pids)
  83. def __getitem__(self, index):
  84. pid, img_path = self.image_pids[index], self.img_paths[index]
  85. img = read_image(img_path)
  86. if self.transform is not None:
  87. img = self.transform(img)
  88. return pid, img
  89. class TextDataset(Dataset):
  90. def __init__(self,
  91. caption_pids,
  92. captions,
  93. text_length: int = 77,
  94. truncate: bool = True):
  95. self.caption_pids = caption_pids
  96. self.captions = captions
  97. self.text_length = text_length
  98. self.truncate = truncate
  99. # self.tokenizer = SimpleTokenizer(bpe_path="/home/linkslinks/文档/ai/GroupViT/datasets/bpe_simple_vocab_16e6.txt.gz")
  100. self.tokenizer = SimpleTokenizer(bpe_path="/mnt/vos-s9gjtkm2/reid/groupvit/GroupViT/datasets/bpe_simple_vocab_16e6.txt.gz")
  101. def __len__(self):
  102. return len(self.caption_pids)
  103. def __getitem__(self, index):
  104. pid, caption = self.caption_pids[index], self.captions[index]
  105. caption = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate)
  106. return pid, caption
  107. class ImageTextMLMDataset(Dataset):
  108. def __init__(self,
  109. dataset,
  110. transform=None,
  111. text_length: int = 77,
  112. truncate: bool = True):
  113. self.dataset = dataset
  114. self.transform = transform
  115. self.text_length = text_length
  116. self.truncate = truncate
  117. self.tokenizer = SimpleTokenizer()
  118. def __len__(self):
  119. return len(self.dataset)
  120. def __getitem__(self, index):
  121. pid, image_id, img_path, caption = self.dataset[index]
  122. img = read_image(img_path)
  123. if self.transform is not None:
  124. img = self.transform(img)
  125. caption_tokens = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate)
  126. mlm_tokens, mlm_labels = self._build_random_masked_tokens_and_labels(caption_tokens.cpu().numpy())
  127. ret = {
  128. 'pids': pid,
  129. 'image_ids': image_id,
  130. 'images': img,
  131. 'caption_ids': caption_tokens,
  132. 'mlm_ids': mlm_tokens,
  133. 'mlm_labels': mlm_labels
  134. }
  135. return ret
  136. def _build_random_masked_tokens_and_labels(self, tokens):
  137. """
  138. Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
  139. :param tokens: list of int, tokenized sentence.
  140. :return: (list of int, list of int), masked tokens and related labels for MLM prediction
  141. """
  142. mask = self.tokenizer.encoder["<|mask|>"]
  143. token_range = list(range(1, len(self.tokenizer.encoder)-3)) # 1 ~ 49405
  144. labels = []
  145. for i, token in enumerate(tokens):
  146. if 0 < token < 49405:
  147. prob = random.random()
  148. # mask token with 15% probability
  149. if prob < 0.15:
  150. prob /= 0.15
  151. # 80% randomly change token to mask token
  152. if prob < 0.8:
  153. tokens[i] = mask
  154. # 10% randomly change token to random token
  155. elif prob < 0.9:
  156. tokens[i] = random.choice(token_range)
  157. # -> rest 10% randomly keep current token
  158. # append current token to output (we will predict these later)
  159. labels.append(token)
  160. else:
  161. # no masking token (will be ignored by loss function later)
  162. labels.append(0)
  163. else:
  164. labels.append(0)
  165. if all(l == 0 for l in labels):
  166. # at least mask 1
  167. labels[1] = tokens[1]
  168. tokens[1] = mask
  169. return torch.tensor(tokens), torch.tensor(labels)