# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
# Modified by Jilan Xu
# -------------------------------------------------------------------------
import os.path as osp
import random
import warnings
from functools import partial

import nltk
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import collate
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import timm

if timm.__version__ == '0.6.12':
    from timm.data.transforms import str_to_pil_interp as _pil_interp
else:
    from timm.data.transforms import _pil_interp
# this works for timm==0.3.2
# from timm.data.transforms import _pil_interp

from torchvision import transforms
import torch.nn as nn
from PIL import ImageFilter, Image
from torch import Tensor
from typing import Tuple, List, Optional
import numbers
import math
import torchvision.transforms.functional as F
import shutil

from .formatting import ToDataContainer
from .tokenizer import SimpleTokenizer
from .clip_dataset import ClipDataset
from ipdb import set_trace


def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals to
    # num_worker * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
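
# How worker_init_fn is typically wired into a DataLoader (a sketch, not taken
# from this file; `partial` is imported above for exactly this kind of use):
#   init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
#   torch.utils.data.DataLoader(..., worker_init_fn=init_fn)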


def collate_fn(batch):
    """Stack tensor fields into batch tensors and gather string fields into lists.

    Optional keys ('raw_question', 'cross_image', ...) fall back to None when
    absent from the samples.
    """
    img = torch.stack([b['image'] for b in batch])
    caption = torch.stack([b['caption'] for b in batch])
    raw_caption = [b['raw_caption'] for b in batch]
    raw_question = [b['raw_question'] for b in batch] if 'raw_question' in batch[0].keys() else None
    raw_answer = [b['raw_answer'] for b in batch] if 'raw_answer' in batch[0].keys() else None
    cross_image = torch.stack([b['cross_image'] for b in batch]) if 'cross_image' in batch[0].keys() else None
    cross_entity = [b['cross_entity'] for b in batch] if 'cross_entity' in batch[0].keys() else None
    question = torch.stack([b['question'] for b in batch]) if 'question' in batch[0].keys() and batch[0]['question'] is not None else None
    answer = torch.stack([b['answer'] for b in batch]) if 'answer' in batch[0].keys() and batch[0]['answer'] is not None else None
    return {
        'image': img,
        'caption': caption,
        'raw_caption': raw_caption,
        'raw_question': raw_question,
        'raw_answer': raw_answer,
        'cross_image': cross_image,
        'cross_entity': cross_entity,
        'question': question,
        'answer': answer,
    }


def build_loader(config):
    local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0
    global_rank = dist.get_rank() if dist.is_initialized() else 0

    dataset_train = build_dataset(is_train=True, config=config)
    print(f'local rank {local_rank} / global rank {global_rank} '
          f'successfully built train dataset')
    dataset_val = build_dataset(is_train=False, config=config)
    print(f'local rank {local_rank} / global rank {global_rank} '
          f'successfully built val dataset')

    sampler_train = torch.utils.data.DistributedSampler(dataset_train, shuffle=True)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    print('train batch size: ', config.train.batch_size)
    print('val batch size: ', config.val.batch_size)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=config.train.batch_size,
        num_workers=config.train.num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=True,
        collate_fn=collate_fn,  ### NOTE: custom collate_fn, see above ###
        # shuffle=False,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=config.val.batch_size,
        num_workers=config.val.num_workers,
        pin_memory=True,
        drop_last=False,
        persistent_workers=True,
    )
    return dataset_train, dataset_val, data_loader_train, data_loader_val
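
# Usage sketch (config names are hypothetical, mirroring the attribute access
# above, e.g. cfg.data.train.batch_size; DistributedSampler assumes the
# process group has been initialized):
#   dataset_train, dataset_val, train_loader, val_loader = build_loader(cfg.data)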


def build_dataset(is_train, config):
    img_transform = build_img_transform(is_train, config.img_aug, config.with_dc)
    text_transform = build_text_transform(is_train, config.text_aug, config.with_dc)
    split = 'train' if is_train else 'val'

    dataset = ClipDataset(
        root_dir=config[split]['root_dir'],
        meta_file=config[split]['meta_file'],
        img_transform=img_transform,
        text_transform=text_transform,
        read_from=config[split]['read_from'],
        split=split,
        cross_image=config[split].get('cross_image', False),
        mask_type=config[split].get('mask_type', 'class'),
        use_distilbert=config[split].get('use_distilbert', True),
    )
    print('dataset len: ', len(dataset))
    # for i in range(10):
    #     t = dataset.__getitem__(i)
    #     print(t['image'].shape, t['cross_image'].shape)
    #     print(t['caption'].shape, t['target'])
    #     print(t['raw_caption'])
    #     print(t['cross_caption'], '\t', t['cross_entity'])
    #     print(t['raw_question'], '\t', t['raw_answer'])
    #     set_trace()
    return dataset


def build_img_transform(is_train, config, with_dc=True):
    if not config.deit_aug:
        if is_train:
            transform = transforms.Compose([
                transforms.RandomResizedCrop(config.img_size, scale=config.img_scale),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ])
        else:
            transform = transforms.Compose([
                transforms.Resize(config.img_size + 32),
                transforms.CenterCrop(config.img_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ])
        return transform

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=config.img_size,
            is_training=True,
            color_jitter=config.color_jitter if config.color_jitter > 0 else None,
            auto_augment=config.auto_augment if config.auto_augment != 'none' else None,
            re_prob=config.re_prob,
            re_mode=config.re_mode,
            re_count=config.re_count,
            interpolation=config.interpolation,
        )
    else:
        size = int((256 / 224) * config.img_size)
        transform = transforms.Compose([
            transforms.Resize(size, interpolation=_pil_interp(config.interpolation)),
            transforms.CenterCrop(config.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ])

    if with_dc:
        transform = transforms.Compose([*transform.transforms, ToDataContainer()])
    return transform
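
# Sketch of the deit-aug eval branch above: with img_size=224, images are
# resized to int(256/224 * 224) = 256 on the short side, center-cropped to
# 224, and normalized, yielding a 3x224x224 float tensor per image.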


def build_text_transform(is_train, config, with_dc=True):
    local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0
    if is_train:
        ### only on local rank 0 ###
        if local_rank == 0:
            ### download itself or pre-download and give the nltk dir ###
            # nltk.download('popular')
            nltk.data.path.append('/mnt/petrelfs/xujilan/nltk_data')
        transform = WordAugTokenizeWrapper(
            Tokenize(SimpleTokenizer(), max_seq_len=config.max_seq_len),
            max_word=config.multi_label,
            word_type=config.word_type)
    else:
        transform = Tokenize(SimpleTokenizer(), max_seq_len=config.max_seq_len)

    if with_dc:
        transform = transforms.Compose([transform, ToDataContainer()])
    return transform
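
# Usage sketch: at train time (with with_dc=False) the transform returns a
# 4-tuple (tokens, nouns, locs, texts); at val time it returns a single token
# tensor. Names below are illustrative, not from this file:
#   text_transform = build_text_transform(is_train=True, config=cfg.text_aug, with_dc=False)
#   tokens, nouns, locs, texts = text_transform('a dog runs on the grass')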


class Tokenize:

    def __init__(self, tokenizer, max_seq_len=77, truncate=True):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.truncate = truncate

    def __call__(self, texts):
        expanded_dim = False
        if isinstance(texts, str):
            texts = [texts]
            expanded_dim = True

        sot_token = self.tokenizer.encoder['<|startoftext|>']
        eot_token = self.tokenizer.encoder['<|endoftext|>']
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), self.max_seq_len, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > self.max_seq_len:
                if self.truncate:
                    tokens = tokens[:self.max_seq_len]
                    tokens[-1] = eot_token
                else:
                    raise RuntimeError(f'Input {texts[i]} is too long for context length {self.max_seq_len}')
            result[i, :len(tokens)] = torch.tensor(tokens)

        if expanded_dim:
            return result[0]
        return result
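
# Minimal usage sketch (CLIP-style tokenization, zero-padded to max_seq_len):
#   tok = Tokenize(SimpleTokenizer(), max_seq_len=77)
#   tok('a photo of a cat').shape    # torch.Size([77])   - single string
#   tok(['a cat', 'a dog']).shape    # torch.Size([2, 77]) - list of strings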


class WordAugTokenizeWrapper:

    def __init__(self, tokenize, max_word=3, template_set='full', word_type='noun'):
        self.tokenize = tokenize
        self.max_word = max_word
        from .imagenet_template import (full_imagenet_templates, sub_imagenet_template,
                                        simple_imagenet_template, identity_template)
        assert template_set in ['full', 'subset', 'simple', 'identity']
        if template_set == 'full':
            templates = full_imagenet_templates
        elif template_set == 'subset':
            templates = sub_imagenet_template
        elif template_set == 'simple':
            templates = simple_imagenet_template
        elif template_set == 'identity':
            templates = identity_template
        else:
            raise ValueError
        self.templates = templates
        assert word_type in ['noun', 'noun_phrase']
        self.word_type = word_type

    def get_tag(self, tokenized, tags):
        if not isinstance(tags, (list, tuple)):
            tags = [tags]
        ret = []
        for (word, pos) in nltk.pos_tag(tokenized):
            for tag in tags:
                if pos == tag:
                    ret.append(word)
        return ret

    def get_tag_with_loc(self, tokenized, tags):
        if not isinstance(tags, (list, tuple)):
            tags = [tags]
        ret = []
        loc = []
        for i, (word, pos) in enumerate(nltk.pos_tag(tokenized)):
            for tag in tags:
                if pos == tag:
                    ret.append(word)
                    loc.append(i)
        return ret, loc
    def get_noun_phrase(self, tokenized):
        # Taken from Su Nam Kim Paper...
        grammar = r"""
            NBAR:
                {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns

            NP:
                {<NBAR>}
                {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
        """
        chunker = nltk.RegexpParser(grammar)
        chunked = chunker.parse(nltk.pos_tag(tokenized))
        continuous_chunk = []
        current_chunk = []

        for subtree in chunked:
            if isinstance(subtree, nltk.Tree):
                current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
            elif current_chunk:
                named_entity = ' '.join(current_chunk)
                if named_entity not in continuous_chunk:
                    continuous_chunk.append(named_entity)
                current_chunk = []
            else:
                continue
        # flush the trailing chunk, otherwise a phrase at the end of the
        # sentence would be silently dropped
        if current_chunk:
            named_entity = ' '.join(current_chunk)
            if named_entity not in continuous_chunk:
                continuous_chunk.append(named_entity)
        return continuous_chunk
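
    # Example (illustrative; exact chunks depend on the NLTK POS tagger):
    #   get_noun_phrase(nltk.word_tokenize('a black dog on the green grass'))
    #   -> ['black dog', 'green grass']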
    def __call__(self, text):
        """
        Args:
            text: str
        """
        assert isinstance(text, str)
        tokenized = nltk.word_tokenize(text)
        nouns = []
        locs = []  # initialized so the returns below are always defined

        if len(tokenized) > 0:
            if self.word_type == 'noun':
                # nouns = self.get_tag(tokenized, ['NN', 'NNS', 'NNP', 'VBG', 'VB', 'VBD', 'VBN', 'VBP', 'VBZ'])
                # nouns = self.get_tag(tokenized, ['NN', 'NNS'])
                # nouns, locs = self.get_tag_with_loc(tokenized, ['NN', 'NNS'])
                nouns, locs = self.get_tag_with_loc(tokenized, ['NN', 'NNS', 'NNP'])
            elif self.word_type == 'noun_phrase':
                nouns = self.get_noun_phrase(tokenized)
            else:
                raise ValueError('word_type must be noun or noun_phrase')

        ### By default, we use this ###
        if self.max_word == 0:
            return self.tokenize(text), nouns, locs, text

        prompt_texts = []
        if len(nouns) > 0:
            select_nouns = np.random.choice(nouns, min(self.max_word, len(nouns)), replace=False)
            prompt_texts = [np.random.choice(self.templates).format(noun) for noun in select_nouns]
        if len(prompt_texts) < self.max_word:
            prompt_texts += [text] * (self.max_word - len(prompt_texts))

        texts = [text] + prompt_texts
        return self.tokenize(texts), nouns, locs, texts
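

# A minimal end-to-end sketch of the word-augmentation path (illustrative
# only; SimpleTokenizer and the template sets come from this package):
#   wrapper = WordAugTokenizeWrapper(
#       Tokenize(SimpleTokenizer(), max_seq_len=77),
#       max_word=3, template_set='simple', word_type='noun')
#   tokens, nouns, locs, texts = wrapper('a dog chases a ball in the park')
#   # tokens: (4, 77) LongTensor - the caption plus 3 templated noun prompts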