clip_dataset.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # -------------------------------------------------------------------------
  2. # Written by Jilan Xu
  3. # -------------------------------------------------------------------------
  4. from re import L
  5. import torch
  6. import json
  7. import os.path as osp
  8. import requests
  9. import numpy as np
  10. import time
  11. import ast
  12. from typing import List
  13. from torch.utils.data import Dataset
  14. import random
  15. import os
  16. import pandas as pd
  17. import omegaconf
  18. import clip
  19. from .tokenizer import SimpleTokenizer
  20. from .imagenet_template import full_imagenet_templates
  21. from nltk.stem import WordNetLemmatizer
  22. from PIL import Image
  23. from ipdb import set_trace
  24. import io
  25. lemmatizer = WordNetLemmatizer()
  26. ### frequently appeared 100 entities ###
  27. TOP_CLASSES_1=[
  28. 'people', 'man', 'men', 'woman', 'women', 'girl', 'boy', 'lady', 'kid', 'child', 'children', 'baby', 'student', 'bride', 'groom', 'couple', 'prince', 'princess', \
  29. 'car', 'bus', 'truck', 'motorcycle', 'train', 'bicycle', 'boat', 'aeroplane', 'airplane', 'motorbike', 'bike',\
  30. 'cup', 'bottle', 'bowl', 'knife', 'spoon', 'glass', 'fork',\
  31. 'chair', 'table', 'bench', 'clock', 'laptop', 'light', 'vase', 'plant', 'remote', 'microwave', 'toaster', 'oven','mouse', 'keyboard','sofa', 'monitor','desk', 'tv','TV', 'couch', 'flower','refrigerator', \
  32. 'house', 'building', 'hotel',\
  33. 'handbag', 'umbrella','book', 'backpack', 'phone', 'shirt', 'tie', 'suitcase','T-shirt', 'bag', 'box', \
  34. 'sink','bed','toilet',\
  35. 'cat','dog', 'horse', 'bird','cow', 'sheep' ,'elephant', 'bear', 'zebra', 'giraffe', \
  36. 'ball', 'racket', 'skateboard', 'skis', 'snowboard', 'surfboard', 'kite', \
  37. 'pizza', 'cake', 'apple', 'banana', 'sandwich', 'orange', 'carrot', 'donut' ,\
  38. ]
  39. ### some of the entities are similar, map them to a single one ###
  40. syn_dict = {
  41. 'people':'people', 'man':'people', 'men':'people', 'woman':'people', 'women':'people', 'girl':'people', 'boy':'people', 'lady':'people', 'kid':'people', 'child':'people', 'children':'people', 'baby':'people', 'student':'people', 'bride':'people', 'groom':'people', 'couple':'people', 'prince':'people', 'princess':'people',\
  42. 'airplane': 'aeroplane','motorbike': 'motorcycle','bike': 'bicycle',\
  43. 'TV':'tv', 'desk': 'table', 'couch':'sofa',\
  44. 'building': 'house', 'hotel': 'house', \
  45. 'T-shirt': 'shirt','T-Shirt': 'shirt', 'handbag': 'bag', \
  46. }
  47. ### unique entities ###
  48. TOP_UNIQUE_CLASSES = [
  49. 'people', 'car', 'bus', 'truck', 'motorcycle', \
  50. 'train', 'bicycle', 'boat', 'aeroplane', 'cup', \
  51. 'bottle', 'bowl', 'knife', 'spoon', 'glass', \
  52. 'fork', 'chair', 'table', 'bench', 'clock', \
  53. 'laptop', 'light', 'vase', 'plant', 'remote',\
  54. 'microwave', 'toaster', 'oven','mouse', 'keyboard',\
  55. 'sofa', 'monitor', 'tv', 'flower','refrigerator', \
  56. 'house', 'bag', 'umbrella','book', 'backpack', \
  57. 'phone', 'shirt', 'tie', 'suitcase', 'box',\
  58. 'sink','bed','toilet', 'cat','dog', \
  59. 'horse', 'bird','cow', 'sheep' ,'elephant', \
  60. 'bear', 'zebra', 'giraffe', 'ball', 'racket', \
  61. 'skateboard', 'skis', 'snowboard', 'surfboard', 'kite',\
  62. 'pizza', 'cake', 'apple', 'banana', 'sandwich',\
  63. 'orange', 'carrot', 'donut' ,\
  64. ]
  65. TOP_UNIQUE_CLASSES_IDX = {}
  66. for i, x in enumerate(TOP_UNIQUE_CLASSES):
  67. TOP_UNIQUE_CLASSES_IDX[x] = i
  68. class ClipDataset(Dataset):
  69. def __init__(self, root_dir, meta_file, img_transform=None, text_transform=None,
  70. read_from='dir',
  71. label_texts_ensemble='none', split='train',
  72. cross_image=False, use_entity=True, mask_type='class', use_distilbert=True
  73. ):
  74. self.root_dir = root_dir
  75. self.meta_file = meta_file
  76. self.metas = pd.read_csv(self.meta_file)
  77. print(f'Total {len(self.metas)} samples')
  78. self.read_from = read_from
  79. if self.read_from == 'petrel':
  80. from petrel_client.client import Client
  81. self.client = Client()
  82. self.img_transform = img_transform
  83. self.text_transform = text_transform
  84. self.label_texts_ensemble = label_texts_ensemble
  85. self.split=split
  86. self.cross_image = cross_image
  87. self.use_entity = use_entity
  88. self.tokenizer = SimpleTokenizer()
  89. self.mask_type = mask_type
  90. self.use_distilbert = use_distilbert
  91. def __len__(self):
  92. return len(self.metas)
  93. def load_image(self, filename):
  94. filename = os.path.join(self.root_dir, filename)
  95. if self.read_from == 'dir':
  96. img = Image.open(filename).convert('RGB')
  97. return img
  98. elif self.read_from == 'petrel':
  99. value = self.client.get(filename)
  100. img_bytes = np.frombuffer(value, dtype=np.uint8)
  101. with Image.open(io.BytesIO(img_bytes)) as img:
  102. img = img.convert('RGB')
  103. return img
  104. else:
  105. raise NotImplementedError
  106. def _load_meta(self, idx):
  107. return self.metas.iloc[idx]
  108. def sample_cross_image(self, curr_meta):
  109. pair_index = curr_meta['pairindex']
  110. pair_entity = curr_meta['pairentity']
  111. pair_index_list = ast.literal_eval(pair_index)
  112. pair_entity_list = ast.literal_eval(pair_entity)
  113. sample_index = np.random.randint(0, len(pair_index_list))
  114. index = pair_index_list[sample_index]
  115. entity = pair_entity_list[sample_index]
  116. pair_meta = self._load_meta(index)
  117. img = self.load_image(pair_meta['image_id'])
  118. caption = pair_meta['caption']
  119. return img, caption, entity
  120. def __getitem__(self, idx):
  121. curr_meta = self._load_meta(idx)
  122. filename = curr_meta['image_id']
  123. raw_caption = caption = curr_meta['caption']
  124. label = int(curr_meta['label']) if 'label' in curr_meta else -1
  125. ret_info = {}
  126. try:
  127. assert self.is_contains_chinese(caption) == False
  128. img = self.load_image(filename)
  129. if self.img_transform is not None:
  130. image = self.img_transform(img)
  131. if self.text_transform is not None:
  132. if self.split == 'train':
  133. ### for clip TextTransformer, captions are here tokenised ###
  134. ### for bert/distilbert, text transform are used to select nouns, captions will be tokensized later ###
  135. caption, nouns, locs, prompt_texts = self.text_transform(caption)
  136. if self.use_entity:
  137. ### A feasible option here is to pre-process question and answers to speed-up data loading ###
  138. if self.use_distilbert:
  139. ### bert/distilbert-like, questions/answers will be tokenised later ###
  140. raw_question, question, raw_answer, answer = self.build_question_and_answer_for_distilbert(raw_caption, nouns)
  141. else:
  142. ### clip TextTransformer-like, questions/answers are tokenised ###
  143. raw_question, question, raw_answer, answer = self.build_question_and_answer(raw_caption, nouns)
  144. ret_info['question'] = question
  145. ret_info['answer'] = answer
  146. ret_info['raw_question'] = raw_question
  147. ret_info['raw_answer'] = raw_answer
  148. if self.cross_image:
  149. crossimg, crosscaption, crossentity = self.sample_cross_image(curr_meta)
  150. crossimg = self.img_transform(crossimg)
  151. crossentity = 'A photo of ' + crossentity
  152. ret_info['cross_image'] = crossimg
  153. ret_info['cross_entity'] = crossentity
  154. ret_info['cross_caption'] = crosscaption
  155. else:
  156. caption = self.text_transform(caption)
  157. ret_info['image'] = image
  158. ret_info['caption'] = caption
  159. ret_info['raw_caption'] = raw_caption
  160. ret_info['target'] = label
  161. # ret_info['filename'] = filename
  162. return ret_info
  163. except Exception as e:
  164. return self.__getitem__(np.random.randint(0, len(self.metas)))
  165. def judge_noun(self, n):
  166. n = n.replace('.', '')
  167. ans = n
  168. ### conduct Lemmatization ###
  169. ans = lemmatizer.lemmatize(ans.lower())
  170. if ans in syn_dict:
  171. ans = syn_dict[ans]
  172. if ans in TOP_UNIQUE_CLASSES:
  173. return 1, ans
  174. return 0, n
  175. def build_question_and_answer(self, caption, nouns):
  176. words = caption.split(' ')
  177. question = ''
  178. ans_list = []
  179. token_mapper = {}
  180. word_mapper = {}
  181. assert self.mask_type == 'class'
  182. for word in words:
  183. word = word.strip("'s").strip(' ').strip('\n')
  184. word_flag, newword = self.judge_noun(word)
  185. if word_flag == 1:
  186. question = question + newword + ' '
  187. ans_list.append(newword)
  188. token_id = self.tokenizer.encode(newword)[0]
  189. token_mapper[token_id] = TOP_UNIQUE_CLASSES_IDX[newword]
  190. word_mapper[token_id] = 332 ### this is 'M'
  191. else:
  192. question = question + word + ' '
  193. question = question.replace("'", '').strip()
  194. raw_question = question
  195. question, _, _, _ = self.text_transform(raw_question)
  196. question = torch.tensor([word_mapper[int(word)] if int(word) in word_mapper else word for word in question])
  197. # raw_answer = 'A photo of ' + ' and '.join(list(set(ans_list))) ## unique words
  198. raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(ans_list)))
  199. answer, _, _, _ = self.text_transform(raw_answer)
  200. return raw_question, question, raw_answer, answer
  201. def build_question_and_answer_for_distilbert(self, caption, nouns):
  202. words = caption.split(' ')
  203. question = ''
  204. entity_list = []
  205. ### default, mask all entites ###
  206. assert self.mask_type == 'class'
  207. for word in words:
  208. word = word.strip("'s").strip(' ').strip('\n')
  209. word_flag, newword = self.judge_noun(word)
  210. if word_flag == 1:
  211. question = question + '[MASK]' + ' '
  212. entity_list.append(newword)
  213. else:
  214. question = question + word + ' '
  215. question = question.replace("'", '').strip()
  216. raw_question = question
  217. #### build and transform answers ###
  218. # raw_answer = 'A photo of ' + ' and '.join(list(set(ans_list))) ## unique words
  219. raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(entity_list)))
  220. return raw_question, None, raw_answer, None
  221. def is_contains_chinese(self, strs):
  222. for _char in strs:
  223. if '\u4e00' <= _char <= '\u9fa5':
  224. return True
  225. return False