# -------------------------------------------------------------------------
# Written by Jilan Xu
# -------------------------------------------------------------------------

import torch
import json
import os.path as osp
import requests
import numpy as np
import time
from typing import List
from .base_dataset import BaseDataset
# from prototype.data.image_reader import build_image_reader
from .image_reader import build_image_reader
# import linklink as link
import random
import os
import omegaconf
import clip
# from ipdb import set_trace  # debugging helper only
from .tokenizer import SimpleTokenizer
from .imagenet_template import full_imagenet_templates
from nltk.stem import WordNetLemmatizer

lemmatizer = WordNetLemmatizer()
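# e.g. lemmatizer.lemmatize('dogs') -> 'dog'; WordNet's noun exception list also
# handles irregular plurals such as 'children' -> 'child' and 'men' -> 'man'.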

### the 100 most frequently appearing entities ###
TOP_CLASSES_1 = [
    'people', 'man', 'men', 'woman', 'women', 'girl', 'boy', 'lady', 'kid',
    'child', 'children', 'baby', 'student', 'bride', 'groom', 'couple',
    'prince', 'princess',
    'car', 'bus', 'truck', 'motorcycle', 'train', 'bicycle', 'boat',
    'aeroplane', 'airplane', 'motorbike', 'bike',
    'cup', 'bottle', 'bowl', 'knife', 'spoon', 'glass', 'fork',
    'chair', 'table', 'bench', 'clock', 'laptop', 'light', 'vase', 'plant',
    'remote', 'microwave', 'toaster', 'oven', 'mouse', 'keyboard', 'sofa',
    'monitor', 'desk', 'tv', 'TV', 'couch', 'flower', 'refrigerator',
    'house', 'building', 'hotel',
    'handbag', 'umbrella', 'book', 'backpack', 'phone', 'shirt', 'tie',
    'suitcase', 'T-shirt', 'bag', 'box',
    'sink', 'bed', 'toilet',
    'cat', 'dog', 'horse', 'bird', 'cow', 'sheep', 'elephant', 'bear',
    'zebra', 'giraffe',
    'ball', 'racket', 'skateboard', 'skis', 'snowboard', 'surfboard', 'kite',
    'pizza', 'cake', 'apple', 'banana', 'sandwich', 'orange', 'carrot', 'donut',
]
### some entities are similar; map them to a single canonical one ###
syn_dict = {
    'people': 'people', 'man': 'people', 'men': 'people', 'woman': 'people',
    'women': 'people', 'girl': 'people', 'boy': 'people', 'lady': 'people',
    'kid': 'people', 'child': 'people', 'children': 'people', 'baby': 'people',
    'student': 'people', 'bride': 'people', 'groom': 'people', 'couple': 'people',
    'prince': 'people', 'princess': 'people',
    'airplane': 'aeroplane', 'motorbike': 'motorcycle', 'bike': 'bicycle',
    'TV': 'tv', 'desk': 'table', 'couch': 'sofa',
    'building': 'house', 'hotel': 'house',
    'T-shirt': 'shirt', 'T-Shirt': 'shirt', 'handbag': 'bag',
}
### unique entities (after synonym merging) ###
TOP_UNIQUE_CLASSES = [
    'people', 'car', 'bus', 'truck', 'motorcycle',
    'train', 'bicycle', 'boat', 'aeroplane', 'cup',
    'bottle', 'bowl', 'knife', 'spoon', 'glass',
    'fork', 'chair', 'table', 'bench', 'clock',
    'laptop', 'light', 'vase', 'plant', 'remote',
    'microwave', 'toaster', 'oven', 'mouse', 'keyboard',
    'sofa', 'monitor', 'tv', 'flower', 'refrigerator',
    'house', 'bag', 'umbrella', 'book', 'backpack',
    'phone', 'shirt', 'tie', 'suitcase', 'box',
    'sink', 'bed', 'toilet', 'cat', 'dog',
    'horse', 'bird', 'cow', 'sheep', 'elephant',
    'bear', 'zebra', 'giraffe', 'ball', 'racket',
    'skateboard', 'skis', 'snowboard', 'surfboard', 'kite',
    'pizza', 'cake', 'apple', 'banana', 'sandwich',
    'orange', 'carrot', 'donut',
]

TOP_UNIQUE_CLASSES_IDX = {x: i for i, x in enumerate(TOP_UNIQUE_CLASSES)}
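# e.g. TOP_UNIQUE_CLASSES_IDX['people'] == 0 and TOP_UNIQUE_CLASSES_IDX['car'] == 1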


class ClipDataset(BaseDataset):
    """
    Clip Dataset.

    Arguments:
        - root_dir (:obj:`str`): root directory of dataset
        - meta_file (:obj:`str`): name of meta file
        - transform (list of ``Transform`` objects): list of transforms
        - read_from (:obj:`str`): read type from the original meta_file
        - evaluator (:obj:`Evaluator`): evaluate to get metrics
        - image_reader_type (:obj:`str`): reader type, 'pil' or 'ks'
        - osg_server (:obj:`str`): '10.198.3.28:30080/components/osg-default/v1'
        - topnoun: 'none' / 'coco_top50' / 'cc3m_top50' / ...

    Metafile example::
        "{"filename": "n01440764/n01440764_10026.JPEG", "label": 0, "label_name": "dog"}\n"
    """
    def __init__(self, root_dir, meta_file, img_transform=None, text_transform=None,
                 read_from='mc', evaluator=None, image_reader_type='pil',
                 fseek=False, label_texts_ensemble='none', split='train',
                 cross_image=False, use_entity=True, mask_type='class',
                 use_distilbert=True, class_label_dir=None, sample_list_dir=None,
                 ):
        if not isinstance(meta_file, List) and not isinstance(meta_file, omegaconf.listconfig.ListConfig):
            meta_file = [meta_file]
        if not isinstance(root_dir, List) and not isinstance(root_dir, omegaconf.listconfig.ListConfig):
            root_dir = [root_dir]

        self.meta_file = meta_file
        self.root_dir = root_dir
        self.read_from = read_from
        self.img_transform = img_transform
        self.text_transform = text_transform
        self.evaluator = evaluator
        self.image_reader = build_image_reader(image_reader_type)

        self.fseek = fseek
        self.initialized = False
        self.label_texts_ensemble = label_texts_ensemble
        self.num = 0
        self.split = split

        self.cross_image = cross_image
        self.use_entity = use_entity
        self.tokenizer = SimpleTokenizer()
        self.mask_type = mask_type
        self.use_distilbert = use_distilbert
        if self.cross_image:
            self._load_meta_class_dict(class_label_dir, sample_list_dir)

        self.metas = []

        ### fseek loads each line on demand via a file-offset pointer ###
        ### this saves memory at the cost of extra loading time ###
        if self.fseek:
            self.line_offsets = []
            for each_meta_file in meta_file:
                line_offset = []
                offset = 0
                with open(each_meta_file) as f:
                    for line in f:
                        line_offset.append(offset)
                        offset += len(line.encode('UTF-8'))
                self.num += len(line_offset)
                self.line_offsets.append(line_offset)
        else:
            ### read from a local file and load all metafile info ###
            for rd, each_meta_file in zip(root_dir, meta_file):
                with open(each_meta_file) as f:
                    lines = f.readlines()
                self.num += len(lines)
                for line in lines:
                    info = json.loads(line)
                    ### prepend root_dir to filename ###
                    info['filename'] = osp.join(rd, info['filename'])
                    self.metas.append(info)

        super(ClipDataset, self).__init__(root_dir=root_dir,
                                          meta_file=meta_file,
                                          read_from=read_from,
                                          transform=img_transform,
                                          evaluator=evaluator)

    def __len__(self):
        return self.num

    def _str2list(self, x):
        if isinstance(x, list):
            return x
        elif isinstance(x, str):
            return [x]
        else:
            raise RuntimeError("unknown value for _str2list: {}".format(type(x)))

    def _load_meta(self, idx):
        if self.fseek:
            ### map the global idx to (source file, local idx) ###
            source_id = 0
            while idx >= len(self.line_offsets[source_id]):
                idx -= len(self.line_offsets[source_id])
                source_id += 1
            with open(self.meta_file[source_id]) as f:
                f.seek(self.line_offsets[source_id][idx])
                line = f.readline()
            meta = json.loads(line)
            meta['filename'] = osp.join(self.root_dir[source_id], meta['filename'])
            return meta
        else:
            return self.metas[idx]

    def _load_meta_class_dict(self, class_label_dir, sample_list_dir):
        ### load the class dict, which is used to sample cross images ###
        with open(sample_list_dir) as f:
            self.class_dict = json.loads(f.readline())
        ### load the class label for each sample ###
        with open(class_label_dir) as f:
            self.class_label = json.loads(f.readline())
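    # Assumed file formats, inferred from the usage above and in sample_cross_image:
    # sample_list_dir holds {"<label idx>": [["img1.jpg", "caption"], ...], ...} and
    # class_label_dir holds {"img1.jpg": "<label idx>", ...},
    # each stored as a single-line JSON dict.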

    def sample_cross_image(self, curr_cls):
        ### sample another image that shares the current class label ###
        class_list = self.class_dict[curr_cls]
        filename, caption = random.choice(class_list)
        filename = osp.join(self.root_dir[0], filename)
        curr_meta = {'filename': filename, 'caption': caption}
        img_bytes = self.read_file(curr_meta)
        img = self.image_reader(img_bytes, filename)
        raw_caption = caption
        caption, nouns, locs, _ = self.text_transform(caption)
        return img, caption, raw_caption

    def __getitem__(self, idx):
        curr_meta = self._load_meta(idx)
        filename = curr_meta['filename']
        label = int(curr_meta['label']) if 'label' in curr_meta else -1
        label_name = curr_meta['label_name'] if 'label_name' in curr_meta else None
        caption = curr_meta['caption'] if 'caption' in curr_meta else ''
        raw_caption = caption
        tag = self._str2list(curr_meta['tag']) if 'tag' in curr_meta else []
        ret_info = {}

        try:
            assert not self.is_contains_chinese(caption)
            img_bytes = self.read_file(curr_meta)
            img = self.image_reader(img_bytes, filename)
            image = self.img_transform(img) if self.img_transform is not None else img

            if self.text_transform is not None:
                if self.split == 'train':
                    ### for the clip TextTransformer, captions are tokenised here ###
                    ### for bert/distilbert, the text transform only selects nouns; captions are tokenised later ###
                    caption, nouns, locs, prompt_texts = self.text_transform(caption)

                    if self.use_entity:
                        if self.use_distilbert:
                            ### bert/distilbert-like: questions/answers are tokenised later ###
                            raw_question, question, raw_answer, answer = self.build_question_and_answer_for_distilbert(raw_caption, nouns)
                        else:
                            ### clip TextTransformer-like: questions/answers are tokenised here ###
                            raw_question, question, raw_answer, answer = self.build_question_and_answer(raw_caption, nouns)
                        ret_info['question'] = question
                        ret_info['answer'] = answer
                        ret_info['raw_question'] = raw_question
                        ret_info['raw_answer'] = raw_answer

                    if self.cross_image:
                        imgname = filename.split('/')[-1]
                        top100_label = self.class_label[imgname]  # the label is a str, due to some issues
                        crossimg, crosscaption, cross_rawcaption = self.sample_cross_image(top100_label)
                        crossimg = self.img_transform(crossimg)
                        cross_entity = 'A photo of ' + TOP_UNIQUE_CLASSES[int(top100_label)]
                        ret_info['cross_image'] = crossimg
                        ret_info['cross_entity'] = cross_entity
                else:
                    caption = self.text_transform(caption)

            ret_info['image'] = image
            ret_info['caption'] = caption
            ret_info['target'] = label
            ret_info['raw_caption'] = raw_caption
            return ret_info

        except Exception as e:
            ### fall back to the first sample if this one fails to load ###
            print(e)
            return self.__getitem__(0)

    def judge_noun(self, n):
        ### decide whether word n is a target entity;
        ### return (1, canonical_entity) on a hit, else (0, n) ###
        n = n.replace('.', '')
        ans = n
        ### conduct lemmatization ###
        ans = lemmatizer.lemmatize(ans.lower())
        if ans in syn_dict:
            ans = syn_dict[ans]
        if ans in TOP_UNIQUE_CLASSES:
            return 1, ans
        return 0, n
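    # e.g. judge_noun('dogs.') -> (1, 'dog'); judge_noun('men') -> (1, 'people')
    # via lemmatization plus the 'man' -> 'people' entry in syn_dict;
    # judge_noun('running') -> (0, 'running').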

    def build_question_and_answer(self, caption, nouns):
        words = caption.split(' ')
        question = ''
        ans_list = []

        token_mapper = {}
        word_mapper = {}
        assert self.mask_type == 'class'
        for word in words:
            word_flag, newword = self.judge_noun(word)
            if word_flag == 1:
                question = question + newword + ' '
                ans_list.append(newword)
                token_id = self.tokenizer.encode(newword)[0]
                token_mapper[token_id] = TOP_UNIQUE_CLASSES_IDX[newword]
                word_mapper[token_id] = 332  ### this is 'M' ###
            else:
                question = question + word + ' '

        question = question.replace("'", '').strip()
        raw_question = question

        question, _, _, _ = self.text_transform(raw_question)
        ### remap each entity token id to the mask id ###
        question = torch.tensor([word_mapper[int(word)] if int(word) in word_mapper else word for word in question])
        # raw_answer = 'A photo of ' + ' and '.join(list(set(ans_list)))  ## unique words
        raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(ans_list)))
        answer, _, _, _ = self.text_transform(raw_answer)

        return raw_question, question, raw_answer, answer
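    # e.g. for the caption 'a man rides a horse', raw_question becomes
    # 'a people rides a horse' and, with a template like 'a photo of a {}.',
    # raw_answer becomes 'a photo of a people and horse' (entity order may vary,
    # since the entities pass through a set).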

    def build_question_and_answer_for_distilbert(self, caption, nouns):
        words = caption.split(' ')
        question = ''
        entity_list = []

        ### by default, mask all entities ###
        assert self.mask_type == 'class'
        for word in words:
            word_flag, newword = self.judge_noun(word)
            if word_flag == 1:
                question = question + '[MASK]' + ' '
                entity_list.append(newword)
            else:
                question = question + word + ' '

        question = question.replace("'", '').strip()
        raw_question = question
        ### build the answer from a prompt template; tokenisation happens later ###
        # raw_answer = 'A photo of ' + ' and '.join(list(set(entity_list)))  ## unique words
        raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(entity_list)))
        return raw_question, None, raw_answer, None
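    # e.g. 'a man rides a horse' -> raw_question 'a [MASK] rides a [MASK]';
    # the tokenised question/answer slots stay None, since distilbert tokenises later.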

    def is_contains_chinese(self, strs):
        ### True if any character falls in the CJK unified ideographs range ###
        for _char in strs:
            if '\u4e00' <= _char <= '\u9fa5':
                return True
        return False

    def _get_label_text(self, text):
        # label_text = ['a photo of ' + text + '.']
        if self.label_texts_ensemble == 'prompt6':
            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt6'
        elif self.label_texts_ensemble == 'prompt8':
            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt8'
        elif self.label_texts_ensemble == 'prompt80':
            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt80'
        elif self.label_texts_ensemble == 'cc':
            return [text]
        else:
            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt1'
        label_text = []
        with open(f) as fin:
            for line in fin.readlines():
                label_text.append(line.replace('{0}', text))
        return label_text

    def get_label_texts(self):
        label_to_name = {}
        for curr_meta in self.metas:
            label = int(curr_meta['label']) if 'label' in curr_meta else None
            label_name = curr_meta['label_name'] if 'label_name' in curr_meta else None
            if label is not None and label_name is not None:
                label_to_name[label] = label_name
        labels = sorted(label_to_name.keys())

        label_texts = []
        label_text_len = []
        for label in labels:
            label_name = label_to_name[label]
            label_text = self._get_label_text(label_name)
            label_texts.extend(label_text)
            label_text_len.append(len(label_text))

        all_len = sum(label_text_len)
        offset = 0
        label_num = len(labels)
        ### one-hot matrix marking which prompt rows belong to which label column ###
        label_texts_ensemble_matrix = torch.zeros(all_len, label_num)
        for lbl, ltl in enumerate(label_text_len):
            label_texts_ensemble_matrix[offset: offset + ltl, lbl] = 1
            offset += ltl
        return label_texts, label_texts_ensemble_matrix
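    # e.g. with 2 labels and 3 prompts per label, the matrix is 6 x 2:
    # rows 0-2 are one-hot on column 0 and rows 3-5 on column 1.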

    def dump(self, writer, output):
        filenames = output['filenames']
        image_ids = output['image_ids']
        label_names = output['label_names']
        captions = output['captions']
        tags = output['tags']
        prediction = self.tensor2numpy(output['prediction'])
        score = self.tensor2numpy(output['score'])
        labels = self.tensor2numpy(output['labels'])
        for _idx in range(len(filenames)):
            res = {
                'image_id': int(image_ids[_idx]),
                'filename': filenames[_idx],
                'label': int(labels[_idx]),
                'label_name': label_names[_idx],
                'caption': captions[_idx],
                'tag': tags[_idx],
                'prediction': int(prediction[_idx]),
                'score': [float('%.8f' % s) for s in score[_idx]],
            }
            writer.write(json.dumps(res, ensure_ascii=False) + '\n')
        writer.flush()
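

# A minimal construction sketch (hypothetical paths; the real pipeline builds
# img_transform / text_transform and the storage backend from its own config,
# and the default read_from='mc' assumes a memcached-style backend in BaseDataset):
if __name__ == '__main__':
    dataset = ClipDataset(
        root_dir='/path/to/images',        # hypothetical image root
        meta_file='/path/to/meta.jsonl',   # hypothetical: one JSON dict per line, see class docstring
        image_reader_type='pil',
        split='train',
        cross_image=False,                 # True also requires class_label_dir / sample_list_dir
        use_entity=False,
    )
    print(len(dataset))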