# -------------------------------------------------------------------------
# Written by Jilan Xu
# -------------------------------------------------------------------------
import torch
import json
import os.path as osp
import requests
import numpy as np
import time
from typing import List
from .base_dataset import BaseDataset
# from prototype.data.image_reader import build_image_reader
from .image_reader import build_image_reader
# import linklink as link
import random
import os
import omegaconf
import clip
from ipdb import set_trace
from .tokenizer import SimpleTokenizer
from .imagenet_template import full_imagenet_templates
from nltk.stem import WordNetLemmatizer
from PIL import Image

lemmatizer = WordNetLemmatizer()

### the 100 most frequently appearing entities ###
TOP_CLASSES_1 = [
    'people', 'man', 'men', 'woman', 'women', 'girl', 'boy', 'lady', 'kid', 'child', 'children', 'baby', 'student', 'bride', 'groom', 'couple', 'prince', 'princess',
    'car', 'bus', 'truck', 'motorcycle', 'train', 'bicycle', 'boat', 'aeroplane', 'airplane', 'motorbike', 'bike',
    'cup', 'bottle', 'bowl', 'knife', 'spoon', 'glass', 'fork',
    'chair', 'table', 'bench', 'clock', 'laptop', 'light', 'vase', 'plant', 'remote', 'microwave', 'toaster', 'oven', 'mouse', 'keyboard', 'sofa', 'monitor', 'desk', 'tv', 'TV', 'couch', 'flower', 'refrigerator',
    'house', 'building', 'hotel',
    'handbag', 'umbrella', 'book', 'backpack', 'phone', 'shirt', 'tie', 'suitcase', 'T-shirt', 'bag', 'box',
    'sink', 'bed', 'toilet',
    'cat', 'dog', 'horse', 'bird', 'cow', 'sheep', 'elephant', 'bear', 'zebra', 'giraffe',
    'ball', 'racket', 'skateboard', 'skis', 'snowboard', 'surfboard', 'kite',
    'pizza', 'cake', 'apple', 'banana', 'sandwich', 'orange', 'carrot', 'donut',
]

### some entities are near-synonyms; map them to a single canonical one ###
syn_dict = {
    'people': 'people', 'man': 'people', 'men': 'people', 'woman': 'people', 'women': 'people', 'girl': 'people', 'boy': 'people', 'lady': 'people', 'kid': 'people', 'child': 'people', 'children': 'people', 'baby': 'people', 'student': 'people', 'bride': 'people', 'groom': 'people', 'couple': 'people', 'prince': 'people', 'princess': 'people',
    'airplane': 'aeroplane', 'motorbike': 'motorcycle', 'bike': 'bicycle',
    'TV': 'tv', 'desk': 'table', 'couch': 'sofa',
    'building': 'house', 'hotel': 'house',
    'T-shirt': 'shirt', 'T-Shirt': 'shirt', 'handbag': 'bag',
}

### unique entities after synonym mapping ###
TOP_UNIQUE_CLASSES = [
    'people', 'car', 'bus', 'truck', 'motorcycle',
    'train', 'bicycle', 'boat', 'aeroplane', 'cup',
    'bottle', 'bowl', 'knife', 'spoon', 'glass',
    'fork', 'chair', 'table', 'bench', 'clock',
    'laptop', 'light', 'vase', 'plant', 'remote',
    'microwave', 'toaster', 'oven', 'mouse', 'keyboard',
    'sofa', 'monitor', 'tv', 'flower', 'refrigerator',
    'house', 'bag', 'umbrella', 'book', 'backpack',
    'phone', 'shirt', 'tie', 'suitcase', 'box',
    'sink', 'bed', 'toilet', 'cat', 'dog',
    'horse', 'bird', 'cow', 'sheep', 'elephant',
    'bear', 'zebra', 'giraffe', 'ball', 'racket',
    'skateboard', 'skis', 'snowboard', 'surfboard', 'kite',
    'pizza', 'cake', 'apple', 'banana', 'sandwich',
    'orange', 'carrot', 'donut',
]

TOP_UNIQUE_CLASSES_IDX = {x: i for i, x in enumerate(TOP_UNIQUE_CLASSES)}


class ClipDataset(BaseDataset):
    """
    Clip Dataset.

    Arguments:
        - root_dir (:obj:`str`): root directory of dataset
        - meta_file (:obj:`str`): name of meta file
        - transform (list of ``Transform`` objects): list of transforms
        - read_from (:obj:`str`): read type from the original meta_file
        - evaluator (:obj:`Evaluator`): evaluator to compute metrics
        - image_reader_type (:obj:`str`): reader type 'pil' or 'ks'
        - osg_server (:obj:`str`): '10.198.3.28:30080/components/osg-default/v1'
        - topnoun: 'none' / 'coco_top50' / 'cc3m_top50' / ...

    Metafile example::
        "{"filename": "n01440764/n01440764_10026.JPEG", "label": 0, "label_name": "dog"}\n"
    """
    def __init__(self, root_dir, meta_file, img_transform=None, text_transform=None,
                 read_from='dir', evaluator=None, image_reader_type='pil',
                 fseek=False, label_texts_ensemble='none', split='train',
                 cross_image=False, use_entity=True, mask_type='class', use_distilbert=True,
                 class_label_dir=None, sample_list_dir=None,
                 ):
        if not isinstance(meta_file, List) and not isinstance(meta_file, omegaconf.listconfig.ListConfig):
            meta_file = [meta_file]
        if not isinstance(root_dir, List) and not isinstance(root_dir, omegaconf.listconfig.ListConfig):
            root_dir = [root_dir]
        self.meta_file = meta_file
        self.root_dir = root_dir
        self.read_from = read_from
        self.img_transform = img_transform
        self.text_transform = text_transform
        self.evaluator = evaluator
        self.image_reader = build_image_reader(image_reader_type)
        self.fseek = fseek
        self.initialized = False
        self.label_texts_ensemble = label_texts_ensemble
        self.num = 0
        self.split = split
        self.cross_image = cross_image
        self.use_entity = use_entity
        self.tokenizer = SimpleTokenizer()
        self.mask_type = mask_type
        self.use_distilbert = use_distilbert
        if self.cross_image:
            self._load_meta_class_dict(class_label_dir, sample_list_dir)
        self.metas = []

        ### fseek keeps only per-line byte offsets and reads each line on demand, ###
        ### which saves memory at the cost of extra loading time ###
        if self.fseek:
            self.line_offsets = []
            for each_meta_file in meta_file:
                line_offset = []
                offset = 0
                with open(each_meta_file) as f:
                    for line in f:
                        line_offset.append(offset)
                        offset += len(line.encode('UTF-8'))
                self.num += len(line_offset)
                self.line_offsets.append(line_offset)
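                # e.g. a metafile whose first line occupies 120 bytes yields
                # line_offset == [0, 120] for its first two lines; _load_meta()
                # later seek()s to one of these offsets and readline()s a record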
        else:
            ### read from local file and load all metafile info ###
            for rd, each_meta_file in zip(root_dir, meta_file):
                with open(each_meta_file) as f:
                    lines = f.readlines()
                self.num += len(lines)
                for line in lines:
                    info = json.loads(line)
                    ### prepend root_dir to filename ###
                    info['filename'] = osp.join(rd, info['filename'])
                    self.metas.append(info)

        super(ClipDataset, self).__init__(root_dir=root_dir,
                                          meta_file=meta_file,
                                          read_from=read_from,
                                          transform=img_transform,
                                          evaluator=evaluator)

    def __len__(self):
        return self.num

    def _str2list(self, x):
        if type(x) is list:
            return x
        elif type(x) is str:
            return [x]
        else:
            raise RuntimeError(
                "unknown value for _str2list: {}".format(type(x)))

    def _load_meta(self, idx):
        if self.fseek:
            ### map the global index to (source file, local line index) ###
            source_id = 0
            while idx >= len(self.line_offsets[source_id]):
                idx -= len(self.line_offsets[source_id])
                source_id += 1
            with open(self.meta_file[source_id]) as f:
                f.seek(self.line_offsets[source_id][idx])
                line = f.readline()
                meta = json.loads(line)
                meta['filename'] = osp.join(self.root_dir[source_id], meta['filename'])
            return meta
        else:
            return self.metas[idx]

    def _load_meta_class_dict(self, class_label_dir, sample_list_dir):
        # load the class dict, which is used to sample the cross-image
        with open(sample_list_dir) as f:
            self.class_dict = json.load(f)
        # load the class label for each sample
        with open(class_label_dir) as f:
            self.class_label = json.load(f)

    def sample_cross_image(self, curr_cls):
        class_list = self.class_dict[curr_cls]
        filename, caption = random.choice(class_list)
        filename = osp.join(self.root_dir[0], filename)
        curr_meta = {'filename': filename, 'caption': caption}
        if self.read_from == 'dir':
            ### load via dir ###
            img = Image.open(filename).convert('RGB')
        else:
            ### load via bytes ###
            img_bytes = self.read_file(curr_meta)
            img = self.image_reader(img_bytes, filename)
        raw_caption = caption
        caption, nouns, locs, _ = self.text_transform(caption)
        return img, caption, raw_caption

    def __getitem__(self, idx):
        curr_meta = self._load_meta(idx)
        filename = curr_meta['filename']
        label = int(curr_meta['label']) if 'label' in curr_meta else -1
        label_name = curr_meta['label_name'] if 'label_name' in curr_meta else None
        caption = curr_meta['caption'] if 'caption' in curr_meta else ''
        raw_caption = caption
        tag = self._str2list(curr_meta['tag']) if 'tag' in curr_meta else []
        ret_info = {}
        try:
            assert not self.is_contains_chinese(caption)

            if self.read_from == 'dir':
                ### load from dir ###
                img = Image.open(filename).convert('RGB')
            else:
                ### load from bytes ###
                img_bytes = self.read_file(curr_meta)
                img = self.image_reader(img_bytes, filename)

            if self.img_transform is not None:
                image = self.img_transform(img)

            if self.text_transform is not None:
                if self.split == 'train':
                    ### for the clip TextTransformer, captions are tokenised here ###
                    ### for bert/distilbert, the text transform only selects nouns; captions are tokenised later ###
                    caption, nouns, locs, prompt_texts = self.text_transform(caption)

                    if self.use_entity:
                        if self.use_distilbert:
                            ### bert/distilbert-like: questions/answers are tokenised later ###
                            raw_question, question, raw_answer, answer = self.build_question_and_answer_for_distilbert(raw_caption, nouns)
                        else:
                            ### clip TextTransformer-like: questions/answers are tokenised here ###
                            raw_question, question, raw_answer, answer = self.build_question_and_answer(raw_caption, nouns)
                        ret_info['question'] = question
                        ret_info['answer'] = answer
                        ret_info['raw_question'] = raw_question
                        ret_info['raw_answer'] = raw_answer

                    if self.cross_image:
                        imgname = filename.split('/')[-1]
                        top100_label = self.class_label[imgname]  # the label is a str index into TOP_UNIQUE_CLASSES
                        crossimg, crosscaption, cross_rawcaption = self.sample_cross_image(top100_label)
                        crossimg = self.img_transform(crossimg)
                        cross_entity = 'A photo of ' + TOP_UNIQUE_CLASSES[int(top100_label)]
                        ret_info['cross_image'] = crossimg
                        ret_info['cross_entity'] = cross_entity
                else:
                    caption = self.text_transform(caption)

            ret_info['image'] = image
            ret_info['caption'] = caption
            ret_info['target'] = label
            ret_info['raw_caption'] = raw_caption
            # ret_info['filename'] = filename
            return ret_info

        except Exception as e:
            print(e)
            # return self.__getitem__(0)

    # An earlier rule-based version of judge_noun with manual plural handling,
    # superseded by the lemmatizer-based version below:
    # def judge_noun(self, n):
    #     n = n.replace('.', '')
    #     ans = n.split("'s")[0].split(',')[0]
    #     ### conduct Lemmatization ###
    #     # ans = nlp(ans)[0].lemma_
    #     if ans in syn_dict:
    #         ans = syn_dict[ans]
    #     elif len(ans) >= 2 and ans[-2:] == 'es' and ans[:-2] in syn_dict:
    #         ans = syn_dict[ans[:-2]]
    #     elif len(ans) >= 1 and ans[-1] == 's' and ans[:-1] in syn_dict:
    #         ans = syn_dict[ans[:-1]]
    #     elif ans.lower() in syn_dict:
    #         ans = syn_dict[ans.lower()]
    #     elif len(ans) >= 2 and ans[-2:] == 'es' and ans.lower()[:-2] in syn_dict:
    #         ans = syn_dict[ans.lower()[:-2]]
    #     elif len(ans) >= 1 and ans[-1] == 's' and ans.lower()[:-1] in syn_dict:
    #         ans = syn_dict[ans.lower()[:-1]]
    #     if ans in TOP_UNIQUE_CLASSES:
    #         return 1, ans
    #     elif len(ans) >= 2 and ans[-2:] == 'es' and ans[:-2] in TOP_UNIQUE_CLASSES:
    #         return 1, ans[:-2]
    #     elif len(ans) >= 1 and ans[-1] == 's' and ans[:-1] in TOP_UNIQUE_CLASSES:
    #         return 1, ans[:-1]
    #     elif ans.lower() in TOP_UNIQUE_CLASSES:
    #         return 1, ans.lower()
    #     elif len(ans) >= 2 and ans.lower()[-2:] == 'es' and ans.lower()[:-2] in TOP_UNIQUE_CLASSES:
    #         return 1, ans.lower()[:-2]
    #     elif len(ans) >= 1 and ans.lower()[-1] == 's' and ans.lower()[:-1] in TOP_UNIQUE_CLASSES:
    #         return 1, ans.lower()[:-1]
    #     return 0, n

    def judge_noun(self, n):
        n = n.replace('.', '')
        ans = n
        ### conduct lemmatization ###
        ans = lemmatizer.lemmatize(ans.lower())
        if ans in syn_dict:
            ans = syn_dict[ans]
        if ans in TOP_UNIQUE_CLASSES:
            return 1, ans
        return 0, n
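    # Example: judge_noun('men') lemmatizes to 'man', which syn_dict maps to
    # 'people' (a TOP_UNIQUE_CLASSES entry), so it returns (1, 'people');
    # judge_noun('running') matches nothing and returns (0, 'running').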

    def build_question_and_answer(self, caption, nouns):
        words = caption.split(' ')
        question = ''
        ans_list = []
        token_mapper = {}
        word_mapper = {}
        assert self.mask_type == 'class'
        for word in words:
            word_flag, newword = self.judge_noun(word)
            if word_flag == 1:
                question = question + newword + ' '
                ans_list.append(newword)
                token_id = self.tokenizer.encode(newword)[0]
                token_mapper[token_id] = TOP_UNIQUE_CLASSES_IDX[newword]
                word_mapper[token_id] = 332  ### token id of 'M' ###
            else:
                question = question + word + ' '
        question = question.replace("'", '').strip()
        raw_question = question

        question, _, _, _ = self.text_transform(raw_question)
        question = torch.tensor([word_mapper[int(word)] if int(word) in word_mapper else word for word in question])
        # raw_answer = 'A photo of ' + ' and '.join(list(set(ans_list)))  ## unique words
        raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(ans_list)))
        answer, _, _, _ = self.text_transform(raw_answer)
        return raw_question, question, raw_answer, answer
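    # A sketch of the construction: for caption 'a dog chases a ball',
    # raw_question is 'a dog chases a ball', and in the tokenised question every
    # entity token ('dog', 'ball') is replaced by 332, the id of 'M';
    # raw_answer becomes e.g. 'a photo of a dog and ball', with the prefix
    # taken from a randomly sampled imagenet template.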

    def build_question_and_answer_for_distilbert(self, caption, nouns):
        words = caption.split(' ')
        question = ''
        entity_list = []

        ### by default, mask all entities ###
        assert self.mask_type == 'class'
        for word in words:
            word_flag, newword = self.judge_noun(word)
            if word_flag == 1:
                question = question + '[MASK]' + ' '
                entity_list.append(newword)
            else:
                question = question + word + ' '
        question = question.replace("'", '').strip()
        raw_question = question

        ### build the raw answer; question and answer are tokenised later ###
        # raw_answer = 'A photo of ' + ' and '.join(list(set(entity_list)))  ## unique words
        raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(entity_list)))
        return raw_question, None, raw_answer, None
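    # For example, caption 'a man riding a horse' gives raw_question
    # 'a [MASK] riding a [MASK]' and raw_answer e.g. 'a photo of a people and horse'
    # (entity order follows set() and may vary).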

    def is_contains_chinese(self, strs):
        ### detect CJK characters in the basic unicode range ###
        for _char in strs:
            if '\u4e00' <= _char <= '\u9fa5':
                return True
        return False

    def _get_label_text(self, text):
        # label_text = ['a photo of ' + text + '.']
        if self.label_texts_ensemble == 'prompt6':
            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt6'
        elif self.label_texts_ensemble == 'prompt8':
            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt8'
        elif self.label_texts_ensemble == 'prompt80':
            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt80'
        elif self.label_texts_ensemble == 'cc':
            return [text]
        else:
            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt1'
        label_text = []
        with open(f) as fin:
            for line in fin.readlines():
                label_text.append(line.replace('{0}', text))
        return label_text
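    # Assuming a pattern file containing lines such as 'a photo of a {0}.',
    # _get_label_text('dog') returns one filled-in prompt per line,
    # e.g. ['a photo of a dog.\n', ...].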

    def get_label_texts(self,):
        label_to_name = {}
        for curr_meta in self.metas:
            label = int(curr_meta['label']) if 'label' in curr_meta else None
            label_name = curr_meta['label_name'] if 'label_name' in curr_meta else None
            if label is not None and label_name is not None:
                label_to_name[label] = label_name
        labels = list(label_to_name.keys())
        labels.sort()

        label_texts = []
        label_text_len = []
        for label in labels:
            label_name = label_to_name[label]
            label_text = self._get_label_text(label_name)
            label_texts.extend(label_text)
            label_text_len.append(len(label_text))

        all_len = sum(label_text_len)
        offset = 0
        label_num = len(labels)
        label_texts_ensemble_matrix = torch.zeros(all_len, label_num)
        for lbl, ltl in enumerate(label_text_len):
            label_texts_ensemble_matrix[offset: offset + ltl, lbl] = 1
            offset += ltl
        return label_texts, label_texts_ensemble_matrix
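    # With 2 labels and 3 prompts per label, label_texts has 6 entries and
    # label_texts_ensemble_matrix is 6x2 and block-diagonal: rows 0-2 set
    # column 0 to 1 and rows 3-5 set column 1 to 1, so per-prompt scores can be
    # pooled back to per-label scores by a matrix product.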

    def dump(self, writer, output):
        filenames = output['filenames']
        image_ids = output['image_ids']
        label_names = output['label_names']
        captions = output['captions']
        tags = output['tags']
        prediction = self.tensor2numpy(output['prediction'])
        score = self.tensor2numpy(output['score'])
        labels = self.tensor2numpy(output['labels'])
        for _idx in range(len(filenames)):
            res = {
                'image_id': int(image_ids[_idx]),
                'filename': filenames[_idx],
                'label': int(labels[_idx]),
                'label_name': label_names[_idx],
                'caption': captions[_idx],
                'tag': tags[_idx],
                'prediction': int(prediction[_idx]),
                'score': [float('%.8f' % s) for s in score[_idx]],
            }
            writer.write(json.dumps(res, ensure_ascii=False) + '\n')
        writer.flush()
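

if __name__ == '__main__':
    # A minimal usage sketch, not part of the training pipeline. The paths and
    # the dummy text transform below are hypothetical; the real pipeline builds
    # its transforms from a config and a tokeniser-backed text transform.
    # Because of the relative imports, run this as a module (python -m ...).
    from torchvision import transforms

    img_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    def dummy_text_transform(caption):
        # mirrors the (caption, nouns, locs, prompt_texts) interface
        return caption, [], [], []

    dataset = ClipDataset(
        root_dir='/path/to/images',       # hypothetical image root
        meta_file='/path/to/meta.jsonl',  # hypothetical JSON-lines metafile
        img_transform=img_transform,
        text_transform=dummy_text_transform,
        read_from='dir',
        split='train',
        use_entity=False,
        cross_image=False,
    )
    print(len(dataset))
    sample = dataset[0]
    if sample is not None:  # __getitem__ returns None on a caught exception
        print(sample['image'].shape, sample['raw_caption'])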