builder.py

# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
import os.path as osp
import random
import warnings
from functools import partial

import nltk
import numpy as np
import torch
import torch.distributed as dist
import webdataset as wds
from braceexpand import braceexpand
from mmcv.parallel import collate
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.transforms import _pil_interp
from torchvision import transforms

from .formatting import ToDataContainer
from .tokenizer import SimpleTokenizer


def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals num_workers * rank + worker_id + user seed.
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
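

# Hedged note (not part of the original file): build_loader below binds every
# argument except worker_id via functools.partial, so the DataLoader workers
# invoke it as init_fn(worker_id). With hypothetical values num_workers=4,
# rank=1, seed=42, worker 0 seeds NumPy/random with 4 * 1 + 0 + 42 = 46,
# giving every (rank, worker) pair a distinct seed.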


def build_loader(config):
    local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0

    dataset_train = build_dataset(is_train=True, config=config)
    print(f'local rank {local_rank} / global rank {dist.get_rank()} successfully built train dataset')
    dataset_val = build_dataset(is_train=False, config=config)
    print(f'local rank {local_rank} / global rank {dist.get_rank()} successfully built val dataset')

    dc_collate = partial(collate, samples_per_gpu=config.batch_size)
    train_len = len(dataset_train)
    init_fn = partial(worker_init_fn, num_workers=config.num_workers, rank=dist.get_rank(), seed=config.seed)
    data_loader_train = wds.WebLoader(
        dataset_train.batched(config.batch_size, dc_collate, partial=False),
        batch_size=None,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        persistent_workers=config.num_workers > 0,
        worker_init_fn=init_fn)

    train_nbatches = max(1, train_len // (config.batch_size * dist.get_world_size()))
    data_loader_train = (data_loader_train.with_epoch(train_nbatches).with_length(train_nbatches))

    data_loader_val = wds.WebLoader(
        dataset_val.batched(config.batch_size, dc_collate),
        batch_size=None,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        persistent_workers=config.num_workers > 0,
        worker_init_fn=init_fn)

    val_len = len(dataset_val)
    val_nbatches = max(1, val_len // (config.batch_size * dist.get_world_size()))
    data_loader_val = (data_loader_val.with_epoch(val_nbatches).with_length(val_nbatches))

    return dataset_train, dataset_val, data_loader_train, data_loader_val
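

# Hedged usage sketch (not part of the original file): how the loaders built
# above might be consumed. Assumes `cfg` is an mmcv-style config exposing the
# fields read by build_loader (batch_size, num_workers, pin_memory, seed,
# shuffle_buffer, dataset, img_aug, text_aug) and that torch.distributed has
# already been initialized.
def _example_build_loader(cfg):
    _, _, loader_train, _ = build_loader(cfg)
    batch = next(iter(loader_train))  # dict with 'image' and 'text' entries
    return batch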


def warn_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning,
    and continue."""
    warnings.warn(repr(exn))
    return True


def build_dataset(is_train, config):
    img_transform = build_img_transform(is_train, config.img_aug)
    text_transform = build_text_transform(is_train, config.text_aug)
    split = 'train' if is_train else 'val'

    dataset_type = None
    tar_file_list = []
    total_length = 0
    for ds in config.dataset[split]:
        ds_meta = config.dataset.meta[ds]
        if dataset_type is None:
            dataset_type = ds_meta.type
        else:
            assert dataset_type == ds_meta.type, \
                'All datasets must be of the same type'

        prefix = ds_meta.prefix
        path = ds_meta.path
        length = ds_meta.length
        cur_tar_file_list = []
        for tar_file in braceexpand(osp.join(path, prefix)):
            if osp.exists(tar_file):
                cur_tar_file_list.append(tar_file)
        print(f'Found {len(cur_tar_file_list)} files for dataset {ds}')
        tar_file_list.extend(cur_tar_file_list)
        total_length += length
    print(f'Found {len(tar_file_list)} files in total for split {split}')

    # yapf: disable
    if is_train:
        dataset = (  # noqa
            wds.WebDataset(tar_file_list, repeat=True, handler=warn_and_continue)
            .shuffle(config.shuffle_buffer)
            .decode('pil', handler=warn_and_continue)
            .rename(image='jpg;png;jpeg', text='text;txt', keep=False, handler=warn_and_continue)
            .map_dict(image=img_transform, text=text_transform, handler=warn_and_continue)
            .with_length(total_length))
    else:
        # zero-shot classification validation
        dataset = (  # noqa
            wds.WebDataset(tar_file_list, repeat=False, handler=warn_and_continue)
            .shuffle(0)
            .decode('pil', handler=warn_and_continue)
            .rename(image='jpg;png;jpeg', target='cls', keep=False)
            .map_dict(image=img_transform, target=ToDataContainer())
            .slice(dist.get_rank(), total_length, dist.get_world_size())
            .with_length(total_length))
    # yapf: enable

    return dataset
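

# Sketch (illustrative, not part of the original file): peek at one decoded
# training sample from the per-sample WebDataset pipeline built above, before
# batching. Assumes `cfg` carries the same dataset/augmentation fields as in
# build_loader.
def _example_inspect_sample(cfg):
    ds = build_dataset(is_train=True, config=cfg)
    sample = next(iter(ds))
    # With the default with_dc=True, both entries arrive wrapped in mmcv
    # DataContainers: the transformed image tensor under 'image' and the
    # token-id tensor under 'text'.
    return sample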


def build_img_transform(is_train, config, with_dc=True):
    if not config.deit_aug:
        if is_train:
            transform = transforms.Compose([
                transforms.RandomResizedCrop(config.img_size, scale=config.img_scale),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
            ])
        else:
            transform = transforms.Compose([
                transforms.Resize(config.img_size + 32),
                transforms.CenterCrop(config.img_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
            ])
        return transform

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=config.img_size,
            is_training=True,
            color_jitter=config.color_jitter if config.color_jitter > 0 else None,
            auto_augment=config.auto_augment if config.auto_augment != 'none' else None,
            re_prob=config.re_prob,
            re_mode=config.re_mode,
            re_count=config.re_count,
            interpolation=config.interpolation,
        )
    else:
        size = int((256 / 224) * config.img_size)
        transform = transforms.Compose([
            transforms.Resize(size, interpolation=_pil_interp(config.interpolation)),
            transforms.CenterCrop(config.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
        ])

    if with_dc:
        transform = transforms.Compose([*transform.transforms, ToDataContainer()])

    return transform
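

# Hedged usage sketch: run the DeiT-style eval pipeline above on a blank
# image. SimpleNamespace is a stand-in for the real mmcv config; the field
# names follow the attributes read by build_img_transform.
def _example_img_transform():
    from types import SimpleNamespace

    from PIL import Image
    cfg = SimpleNamespace(deit_aug=True, img_size=224, interpolation='bicubic')
    transform = build_img_transform(is_train=False, config=cfg, with_dc=False)
    return transform(Image.new('RGB', (256, 256)))  # FloatTensor (3, 224, 224)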


def build_text_transform(is_train, config, with_dc=True):
    local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0
    if config.multi_label and is_train:
        # only download on local rank 0
        if local_rank == 0:
            nltk.download('popular')
        transform = WordAugTokenizeWrapper(
            Tokenize(SimpleTokenizer(), max_seq_len=config.max_seq_len),
            max_word=config.multi_label,
            word_type=config.word_type)
    else:
        transform = Tokenize(SimpleTokenizer(), max_seq_len=config.max_seq_len)

    if with_dc:
        transform = transforms.Compose([transform, ToDataContainer()])

    return transform
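

# Hedged sketch: building the plain (single-label) text transform with a
# stand-in config; only `multi_label` and `max_seq_len` are read on this path.
def _example_text_transform():
    from types import SimpleNamespace
    cfg = SimpleNamespace(multi_label=0, max_seq_len=77)
    tt = build_text_transform(is_train=False, config=cfg, with_dc=False)
    return tt('a photo of a cat')  # LongTensor of shape (77,)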


class Tokenize:

    def __init__(self, tokenizer, max_seq_len=77, truncate=True):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.truncate = truncate

    def __call__(self, texts):
        expanded_dim = False
        if isinstance(texts, str):
            texts = [texts]
            expanded_dim = True

        sot_token = self.tokenizer.encoder['<|startoftext|>']
        eot_token = self.tokenizer.encoder['<|endoftext|>']
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), self.max_seq_len, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > self.max_seq_len:
                if self.truncate:
                    tokens = tokens[:self.max_seq_len]
                    tokens[-1] = eot_token
                else:
                    raise RuntimeError(f'Input {texts[i]} is too long for context length {self.max_seq_len}')
            result[i, :len(tokens)] = torch.tensor(tokens)

        if expanded_dim:
            return result[0]

        return result
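

# Hedged usage sketch for Tokenize: a single string yields a 1-D tensor, a
# list of strings a 2-D batch; both are padded (or truncated) to max_seq_len.
def _example_tokenize():
    tok = Tokenize(SimpleTokenizer(), max_seq_len=77)
    single = tok('a photo of a cat')    # shape (77,)
    batch = tok(['a dog', 'two dogs'])  # shape (2, 77)
    return single, batch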


class WordAugTokenizeWrapper:

    def __init__(self, tokenize, max_word=3, template_set='full', word_type='noun'):
        self.tokenize = tokenize
        self.max_word = max_word
        from .imagenet_template import (full_imagenet_templates, sub_imagenet_template, simple_imagenet_template,
                                        identity_template)
        assert template_set in ['full', 'subset', 'simple', 'identity']
        if template_set == 'full':
            templates = full_imagenet_templates
        elif template_set == 'subset':
            templates = sub_imagenet_template
        elif template_set == 'simple':
            templates = simple_imagenet_template
        elif template_set == 'identity':
            templates = identity_template
        else:
            raise ValueError
        self.templates = templates
        assert word_type in ['noun', 'noun_phrase']
        self.word_type = word_type

    def get_tag(self, tokenized, tags):
        if not isinstance(tags, (list, tuple)):
            tags = [tags]
        ret = []
        for (word, pos) in nltk.pos_tag(tokenized):
            for tag in tags:
                if pos == tag:
                    ret.append(word)
        return ret

    def get_noun_phrase(self, tokenized):
        # Taken from Su Nam Kim Paper...
        grammar = r"""
            NBAR:
                {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns

            NP:
                {<NBAR>}
                {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
        """
        chunker = nltk.RegexpParser(grammar)
        chunked = chunker.parse(nltk.pos_tag(tokenized))
        continuous_chunk = []
        current_chunk = []

        for subtree in chunked:
            if isinstance(subtree, nltk.Tree):
                current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
            elif current_chunk:
                named_entity = ' '.join(current_chunk)
                if named_entity not in continuous_chunk:
                    continuous_chunk.append(named_entity)
                    current_chunk = []
            else:
                continue

        # Fix: flush a trailing chunk. The loop above only emits a chunk when
        # a non-chunk token follows it, so a noun phrase at the very end of
        # the caption would otherwise be silently dropped.
        if current_chunk:
            named_entity = ' '.join(current_chunk)
            if named_entity not in continuous_chunk:
                continuous_chunk.append(named_entity)

        return continuous_chunk
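
    # Illustrative example (hedged): for nltk.word_tokenize('a photo of a
    # small dog'), the grammar above chunks 'photo' and 'small dog' as noun
    # phrases, so get_noun_phrase returns ['photo', 'small dog'] (exact output
    # depends on the installed NLTK tagger and models).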

    def __call__(self, text):
        assert isinstance(text, str)
        tokenized = nltk.word_tokenize(text)
        nouns = []

        if len(tokenized) > 0:
            if self.word_type == 'noun':
                # NB: despite the name, the tag list below covers verb forms
                # (VB*) as well as nouns (NN*).
                nouns = self.get_tag(tokenized, ['NN', 'NNS', 'NNP', 'VBG', 'VB', 'VBD', 'VBN', 'VBP', 'VBZ'])
            elif self.word_type == 'noun_phrase':
                nouns = self.get_noun_phrase(tokenized)
            else:
                raise ValueError('word_type must be noun or noun_phrase')

        prompt_texts = []
        if len(nouns) > 0:
            select_nouns = np.random.choice(nouns, min(self.max_word, len(nouns)), replace=False)
            prompt_texts = [np.random.choice(self.templates).format(noun) for noun in select_nouns]
        if len(prompt_texts) < self.max_word:
            prompt_texts += [text] * (self.max_word - len(prompt_texts))

        texts = [text] + prompt_texts
        return self.tokenize(texts)
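

# Hedged end-to-end sketch: word-augmented tokenization as used for
# multi-label training. Requires the NLTK models downloaded in
# build_text_transform (nltk.download('popular')).
def _example_word_aug():
    wrapper = WordAugTokenizeWrapper(
        Tokenize(SimpleTokenizer(), max_seq_len=77),
        max_word=3, template_set='simple', word_type='noun')
    tokens = wrapper('a photo of a dog playing in the park')
    # tokens has shape (1 + max_word, 77): the raw caption plus up to
    # max_word templated prompts built from sampled nouns/verbs.
    return tokens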