clip_dataset.py

# -------------------------------------------------------------------------
# Written by Jilan Xu
# -------------------------------------------------------------------------
import torch
import json
import os.path as osp
import requests
import numpy as np
import time
from typing import List
from torch.utils.data import Dataset
import random
import os
import omegaconf
import clip
from .tokenizer import SimpleTokenizer
from .imagenet_template import full_imagenet_templates
from nltk.stem import WordNetLemmatizer
from PIL import Image
import io

lemmatizer = WordNetLemmatizer()
### 100 frequently appearing entities ###
TOP_CLASSES_1 = [
    'people', 'man', 'men', 'woman', 'women', 'girl', 'boy', 'lady', 'kid', 'child', 'children', 'baby', 'student', 'bride', 'groom', 'couple', 'prince', 'princess', \
    'car', 'bus', 'truck', 'motorcycle', 'train', 'bicycle', 'boat', 'aeroplane', 'airplane', 'motorbike', 'bike', \
    'cup', 'bottle', 'bowl', 'knife', 'spoon', 'glass', 'fork', \
    'chair', 'table', 'bench', 'clock', 'laptop', 'light', 'vase', 'plant', 'remote', 'microwave', 'toaster', 'oven', 'mouse', 'keyboard', 'sofa', 'monitor', 'desk', 'tv', 'TV', 'couch', 'flower', 'refrigerator', \
    'house', 'building', 'hotel', \
    'handbag', 'umbrella', 'book', 'backpack', 'phone', 'shirt', 'tie', 'suitcase', 'T-shirt', 'bag', 'box', \
    'sink', 'bed', 'toilet', \
    'cat', 'dog', 'horse', 'bird', 'cow', 'sheep', 'elephant', 'bear', 'zebra', 'giraffe', \
    'ball', 'racket', 'skateboard', 'skis', 'snowboard', 'surfboard', 'kite', \
    'pizza', 'cake', 'apple', 'banana', 'sandwich', 'orange', 'carrot', 'donut', \
]

### some entities are similar; map them to a single canonical one ###
syn_dict = {
    'people': 'people', 'man': 'people', 'men': 'people', 'woman': 'people', 'women': 'people', 'girl': 'people', 'boy': 'people', 'lady': 'people', 'kid': 'people', 'child': 'people', 'children': 'people', 'baby': 'people', 'student': 'people', 'bride': 'people', 'groom': 'people', 'couple': 'people', 'prince': 'people', 'princess': 'people', \
    'airplane': 'aeroplane', 'motorbike': 'motorcycle', 'bike': 'bicycle', \
    'TV': 'tv', 'desk': 'table', 'couch': 'sofa', \
    'building': 'house', 'hotel': 'house', \
    'T-shirt': 'shirt', 'T-Shirt': 'shirt', 'handbag': 'bag', \
}
### unique entities ###
TOP_UNIQUE_CLASSES = [
    'people', 'car', 'bus', 'truck', 'motorcycle', \
    'train', 'bicycle', 'boat', 'aeroplane', 'cup', \
    'bottle', 'bowl', 'knife', 'spoon', 'glass', \
    'fork', 'chair', 'table', 'bench', 'clock', \
    'laptop', 'light', 'vase', 'plant', 'remote', \
    'microwave', 'toaster', 'oven', 'mouse', 'keyboard', \
    'sofa', 'monitor', 'tv', 'flower', 'refrigerator', \
    'house', 'bag', 'umbrella', 'book', 'backpack', \
    'phone', 'shirt', 'tie', 'suitcase', 'box', \
    'sink', 'bed', 'toilet', 'cat', 'dog', \
    'horse', 'bird', 'cow', 'sheep', 'elephant', \
    'bear', 'zebra', 'giraffe', 'ball', 'racket', \
    'skateboard', 'skis', 'snowboard', 'surfboard', 'kite', \
    'pizza', 'cake', 'apple', 'banana', 'sandwich', \
    'orange', 'carrot', 'donut', \
]

TOP_UNIQUE_CLASSES_IDX = {}
for i, x in enumerate(TOP_UNIQUE_CLASSES):
    TOP_UNIQUE_CLASSES_IDX[x] = i
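# e.g. TOP_UNIQUE_CLASSES_IDX['people'] == 0 and TOP_UNIQUE_CLASSES_IDX['car'] == 1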

class ClipDataset(Dataset):
    """
    Clip Dataset.

    Arguments:
        - root_dir (:obj:`str`): root directory of dataset
        - meta_file (:obj:`str`): name of meta file
        - transform (list of ``Transform`` objects): list of transforms
        - read_from (:obj:`str`): read type from the original meta_file
        - evaluator (:obj:`Evaluator`): evaluate to get metrics
        - image_reader_type (:obj:`str`): reader type 'pil' or 'ks'
        - osg_server (:obj:`str`): '10.198.3.28:30080/components/osg-default/v1'
        - topnoun: 'none' / 'coco_top50' / 'cc3m_top50' / ...

    Metafile example::
        "{"filename": "n01440764/n01440764_10026.JPEG", "label": 0, "label_name": "dog"}\n"
    """
    def __init__(self, root_dir, meta_file, img_transform=None, text_transform=None,
                 read_from='dir', evaluator=None, image_reader_type='pil',
                 fseek=False, label_texts_ensemble='none', split='train',
                 cross_image=False, use_entity=True, mask_type='class', use_distilbert=True,
                 class_label_dir=None, sample_list_dir=None,
                 ):
        if not isinstance(meta_file, List) and not isinstance(meta_file, omegaconf.listconfig.ListConfig):
            meta_file = [meta_file]
        if not isinstance(root_dir, List) and not isinstance(root_dir, omegaconf.listconfig.ListConfig):
            root_dir = [root_dir]

        self.meta_file = meta_file
        self.root_dir = root_dir
        self.read_from = read_from
        if self.read_from == 'petrel':
            from petrel_client.client import Client
            self.client = Client()

        self.img_transform = img_transform
        self.text_transform = text_transform
        self.evaluator = evaluator
        self.fseek = fseek
        self.initialized = False
        self.label_texts_ensemble = label_texts_ensemble
        self.num = 0
        self.split = split
        self.cross_image = cross_image
        self.use_entity = use_entity
        self.tokenizer = SimpleTokenizer()
        self.mask_type = mask_type
        self.use_distilbert = use_distilbert
        if self.cross_image:
            self._load_meta_class_dict(class_label_dir, sample_list_dir)

        self.metas = []

        ### fseek loads each line on demand via a file-pointer seek ###
        ### this saves memory at the cost of extra loading time ###
        if self.fseek:
            self.line_offsets = []
            for each_meta_file in meta_file:
                line_offset = []
                offset = 0
                with open(each_meta_file) as f:
                    for line in f:
                        line_offset.append(offset)
                        offset += len(line.encode('UTF-8'))
                self.num += len(line_offset)
                self.line_offsets.append(line_offset)
        else:
            ### read from local file and load all metafile info ###
            for rd, each_meta_file in zip(root_dir, meta_file):
                with open(each_meta_file) as f:
                    lines = f.readlines()
                self.num += len(lines)

                for line in lines:
                    info = json.loads(line)
                    filename = osp.join(rd, info['filename'])
                    ### add root_dir to filename ###
                    info['filename'] = filename
                    self.metas.append(info)

    def __len__(self):
        return self.num
    def _str2list(self, x):
        if type(x) is list:
            return x
        elif type(x) is str:
            return [x]
        else:
            raise RuntimeError(
                "unknown value for _str2list: {}".format(type(x)))

    def load_image(self, filename):
        if self.read_from == 'dir':
            img = Image.open(filename).convert('RGB')
            return img
        elif self.read_from == 'petrel':
            value = self.client.get(filename)
            img_bytes = np.frombuffer(value, dtype=np.uint8)
            with Image.open(io.BytesIO(img_bytes)) as img:
                img = img.convert('RGB')
            return img
        else:
            raise NotImplementedError
    def _load_meta(self, idx):
        if self.fseek:
            source_id = 0
            while idx >= len(self.line_offsets[source_id]):
                idx -= len(self.line_offsets[source_id])
                source_id += 1  # fixed
            with open(self.meta_file[source_id]) as f:
                f.seek(self.line_offsets[source_id][idx])
                line = f.readline()
                meta = json.loads(line)
                filename = osp.join(self.root_dir[source_id], meta['filename'])
                meta['filename'] = filename
            return meta
        else:
            return self.metas[idx]

    def _load_meta_class_dict(self, class_label_dir, sample_list_dir):
        # load the class dict, which is used to sample cross_image
        with open(sample_list_dir) as f:
            line = f.readline()
            self.class_dict = json.loads(line)

        # load the class label for each sample
        with open(class_label_dir) as f:
            line = f.readline()
            self.class_label = json.loads(line)
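    # Assumed format (inferred from how these dicts are used below; not documented
    # in the source):
    #   - sample_list_dir holds a single-line JSON object mapping a class-id string
    #     to a list of [filename, caption] pairs (consumed by sample_cross_image);
    #   - class_label_dir holds a single-line JSON object mapping an image basename
    #     to its class-id string (consumed in __getitem__ when cross_image is set).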
    def sample_cross_image(self, curr_cls):
        class_list = self.class_dict[curr_cls]
        filename, caption = random.choice(class_list)
        # curr_meta = self._load_meta(idx)
        # filename = curr_meta['filename']
        filename = osp.join(self.root_dir[0], filename)
        curr_meta = {'filename': filename, 'caption': caption}

        img = self.load_image(filename)
        caption = curr_meta['caption'] if 'caption' in curr_meta else ''
        raw_caption = curr_meta['caption'] if 'caption' in curr_meta else ''
        caption, nouns, locs, _ = self.text_transform(caption)
        return img, caption, raw_caption
    def __getitem__(self, idx):
        curr_meta = self._load_meta(idx)
        filename = curr_meta['filename']

        label = int(curr_meta['label']) if 'label' in curr_meta else -1
        label_name = curr_meta['label_name'] if 'label_name' in curr_meta else None
        caption = curr_meta['caption'] if 'caption' in curr_meta else ''
        raw_caption = curr_meta['caption']
        tag = self._str2list(curr_meta['tag']) if 'tag' in curr_meta else []
        ret_info = {}

        # try:
        assert self.is_contains_chinese(caption) == False
        img = self.load_image(filename)

        if self.img_transform is not None:
            image = self.img_transform(img)

        if self.text_transform is not None:
            if self.split == 'train':
                ### for the clip TextTransformer, captions are tokenised here ###
                ### for bert/distilbert, the text transform only selects nouns; captions are tokenised later ###
                caption, nouns, locs, prompt_texts = self.text_transform(caption)

                if self.use_entity:
                    if self.use_distilbert:
                        ### bert/distilbert-like, questions/answers will be tokenised later ###
                        raw_question, question, raw_answer, answer = self.build_question_and_answer_for_distilbert(raw_caption, nouns)
                    else:
                        ### clip TextTransformer-like, questions/answers are tokenised here ###
                        raw_question, question, raw_answer, answer = self.build_question_and_answer(raw_caption, nouns)

                    ret_info['question'] = question
                    ret_info['answer'] = answer
                    ret_info['raw_question'] = raw_question
                    ret_info['raw_answer'] = raw_answer

                if self.cross_image:
                    imgname = filename.split('/')[-1]
                    top100_label = self.class_label[imgname]  # the label is a str, due to some issues
                    crossimg, crosscaption, cross_rawcaption = self.sample_cross_image(top100_label)
                    # crossimg = tensor_trans(trans(crossimg))
                    crossimg = self.img_transform(crossimg)

                    cross_entity = 'A photo of ' + TOP_UNIQUE_CLASSES[int(top100_label)]
                    ret_info['cross_image'] = crossimg
                    ret_info['cross_entity'] = cross_entity
            else:
                caption = self.text_transform(caption)

        ret_info['image'] = image
        ret_info['caption'] = caption
        ret_info['target'] = label
        ret_info['raw_caption'] = raw_caption
        # ret_info['filename'] = filename
        return ret_info
        # except Exception as e:
        #     return self.__getitem__(0)
    def judge_noun(self, n):
        n = n.replace('.', '')
        ans = n
        ### conduct lemmatization ###
        ans = lemmatizer.lemmatize(ans.lower())

        if ans in syn_dict:
            ans = syn_dict[ans]
        if ans in TOP_UNIQUE_CLASSES:
            return 1, ans
        return 0, n
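    # Example of the mapping above: the word 'women.' is stripped to 'women',
    # lemmatized to 'woman', folded by syn_dict into 'people', which is in
    # TOP_UNIQUE_CLASSES (index 0), so judge_noun returns (1, 'people').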
    def build_question_and_answer(self, caption, nouns):
        words = caption.split(' ')
        question = ''
        ans_list = []

        token_mapper = {}
        word_mapper = {}
        assert self.mask_type == 'class'
        for word in words:
            word_after = word
            word_flag, newword = self.judge_noun(word)
            if word_flag == 1:
                question = question + newword + ' '
                ans_list.append(newword)
                token_id = self.tokenizer.encode(newword)[0]
                token_mapper[token_id] = TOP_UNIQUE_CLASSES_IDX[newword]
                word_mapper[token_id] = 332  ### this is 'M' ###
            else:
                question = question + word + ' '

        question = question.replace("'", '').strip()
        raw_question = question

        question, _, _, _ = self.text_transform(raw_question)
        question = torch.tensor([word_mapper[int(word)] if int(word) in word_mapper else word for word in question])
        # raw_answer = 'A photo of ' + ' and '.join(list(set(ans_list)))  ## unique words
        raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(ans_list)))
        answer, _, _, _ = self.text_transform(raw_answer)

        return raw_question, question, raw_answer, answer
    def build_question_and_answer_for_distilbert(self, caption, nouns):
        words = caption.split(' ')
        question = ''
        entity_list = []

        ### by default, mask all entities ###
        assert self.mask_type == 'class'
        for word in words:
            word_after = word
            word_flag, newword = self.judge_noun(word)
            if word_flag == 1:
                question = question + '[MASK]' + ' '
                entity_list.append(newword)
            else:
                question = question + word + ' '

        question = question.replace("'", '').strip()
        raw_question = question
        ### build and transform answers ###
        # raw_answer = 'A photo of ' + ' and '.join(list(set(ans_list)))  ## unique words
        raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(entity_list)))
        return raw_question, None, raw_answer, None
    def is_contains_chinese(self, strs):
        for _char in strs:
            if '\u4e00' <= _char <= '\u9fa5':
                return True
        return False
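

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes this module
# is imported inside its package (the relative imports above require that), and
# the paths and transform callables below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Hypothetical image transform; in practice this would be the project's
    # torchvision/CLIP preprocessing pipeline.
    def dummy_img_transform(img):
        return img

    # Hypothetical text transform returning (caption, nouns, locs, prompt_texts),
    # matching how self.text_transform is called in the 'train' branch above.
    def dummy_text_transform(caption):
        return caption, [], [], []

    dataset = ClipDataset(
        root_dir='/path/to/images',       # hypothetical
        meta_file='/path/to/meta.jsonl',  # one JSON object per line, see the class docstring
        img_transform=dummy_img_transform,
        text_transform=dummy_text_transform,
        read_from='dir',
        split='train',
        use_entity=False,   # skip question/answer building in this sketch
        cross_image=False,  # no cross-image sampling, so no class_label_dir/sample_list_dir needed
    )
    print(len(dataset))
    sample = dataset[0]  # dict with 'image', 'caption', 'target', 'raw_caption'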