# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
import os.path as osp
import random
import warnings
from functools import partial

import nltk
import numpy as np
import torch
import torch.distributed as dist
import webdataset as wds
from braceexpand import braceexpand
from mmcv.parallel import collate
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.transforms import _pil_interp
from torchvision import transforms

from .formatting import ToDataContainer
from .tokenizer import SimpleTokenizer


def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals
    # num_workers * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
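

# For example (illustrative numbers): with num_workers=4, rank=1 and seed=0,
# worker 2 seeds NumPy and `random` with 4 * 1 + 2 + 0 = 6, so every
# dataloader worker across all ranks draws from a distinct random stream.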


def build_loader(config):
    local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0

    dataset_train = build_dataset(is_train=True, config=config)
    print(f'local rank {local_rank} / global rank {dist.get_rank()} successfully built train dataset')
    dataset_val = build_dataset(is_train=False, config=config)
    print(f'local rank {local_rank} / global rank {dist.get_rank()} successfully built val dataset')

    dc_collate = partial(collate, samples_per_gpu=config.batch_size)
    train_len = len(dataset_train)
    init_fn = partial(worker_init_fn, num_workers=config.num_workers, rank=dist.get_rank(), seed=config.seed)
    data_loader_train = wds.WebLoader(
        dataset_train.batched(config.batch_size, dc_collate, partial=False),
        batch_size=None,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        persistent_workers=config.num_workers > 0,
        worker_init_fn=init_fn)
    # cap each epoch at the number of full batches available per rank
    train_nbatches = max(1, train_len // (config.batch_size * dist.get_world_size()))
    data_loader_train = data_loader_train.with_epoch(train_nbatches).with_length(train_nbatches)

    data_loader_val = wds.WebLoader(
        dataset_val.batched(config.batch_size, dc_collate),
        batch_size=None,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        persistent_workers=config.num_workers > 0,
        worker_init_fn=init_fn)
    val_len = len(dataset_val)
    val_nbatches = max(1, val_len // (config.batch_size * dist.get_world_size()))
    data_loader_val = data_loader_val.with_epoch(val_nbatches).with_length(val_nbatches)

    return dataset_train, dataset_val, data_loader_train, data_loader_val
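

# Illustrative usage (a sketch; `cfg` is assumed to be the data section of an
# mmcv-style config exposing the fields read above: batch_size, num_workers,
# pin_memory, seed, shuffle_buffer, dataset, img_aug and text_aug):
#
#   dataset_train, dataset_val, loader_train, loader_val = build_loader(cfg)
#   for batch in loader_train:
#       images = batch['image']  # collated, normalized image tensors
#       texts = batch['text']    # tokenized captions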


def warn_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning,
    and continue."""
    warnings.warn(repr(exn))
    return True


def build_dataset(is_train, config):
    img_transform = build_img_transform(is_train, config.img_aug)
    text_transform = build_text_transform(is_train, config.text_aug)
    split = 'train' if is_train else 'val'
    dataset_type = None
    tar_file_list = []
    total_length = 0
    for ds in config.dataset[split]:
        ds_meta = config.dataset.meta[ds]
        if dataset_type is None:
            dataset_type = ds_meta.type
        else:
            assert dataset_type == ds_meta.type, \
                'All datasets must be of the same type'

        prefix = ds_meta.prefix
        path = ds_meta.path
        length = ds_meta.length
        cur_tar_file_list = []
        for tar_file in braceexpand(osp.join(path, prefix)):
            if osp.exists(tar_file):
                cur_tar_file_list.append(tar_file)
        print(f'Found {len(cur_tar_file_list)} files for dataset {ds}')
        tar_file_list.extend(cur_tar_file_list)
        total_length += length
    print(f'Found {len(tar_file_list)} files in total for split {split}')
    # yapf: disable
    if is_train:
        dataset = (  # noqa
            wds.WebDataset(tar_file_list, repeat=True, handler=warn_and_continue)
            .shuffle(config.shuffle_buffer)
            .decode('pil', handler=warn_and_continue)
            .rename(image='jpg;png;jpeg', text='text;txt', keep=False, handler=warn_and_continue)
            .map_dict(image=img_transform, text=text_transform, handler=warn_and_continue)
            .with_length(total_length))
    else:
        # zero-shot classification validation; each rank reads a strided
        # slice of the samples so the set is covered exactly once overall
        dataset = (  # noqa
            wds.WebDataset(tar_file_list, repeat=False, handler=warn_and_continue)
            .shuffle(0)
            .decode('pil', handler=warn_and_continue)
            .rename(image='jpg;png;jpeg', target='cls', keep=False)
            .map_dict(image=img_transform, target=ToDataContainer())
            .slice(dist.get_rank(), total_length, dist.get_world_size())
            .with_length(total_length))
    # yapf: enable
    return dataset
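

# Example of the brace-expanded shard matching used above (illustrative names):
#   list(braceexpand('data/part-{000000..000002}.tar'))
#   # -> ['data/part-000000.tar', 'data/part-000001.tar', 'data/part-000002.tar']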


def build_img_transform(is_train, config, with_dc=True):
    if not config.deit_aug:
        if is_train:
            transform = transforms.Compose([
                transforms.RandomResizedCrop(config.img_size, scale=config.img_scale),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
            ])
        else:
            transform = transforms.Compose([
                transforms.Resize(config.img_size + 32),
                transforms.CenterCrop(config.img_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
            ])
        # NOTE: this early return skips the ToDataContainer wrapper below, so
        # with_dc has no effect when deit_aug is disabled
        return transform

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=config.img_size,
            is_training=True,
            color_jitter=config.color_jitter if config.color_jitter > 0 else None,
            auto_augment=config.auto_augment if config.auto_augment != 'none' else None,
            re_prob=config.re_prob,
            re_mode=config.re_mode,
            re_count=config.re_count,
            interpolation=config.interpolation,
        )
    else:
        # standard ImageNet eval protocol: resize the shorter side by 256/224
        size = int((256 / 224) * config.img_size)
        transform = transforms.Compose([
            transforms.Resize(size, interpolation=_pil_interp(config.interpolation)),
            transforms.CenterCrop(config.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
        ])

    if with_dc:
        transform = transforms.Compose([*transform.transforms, ToDataContainer()])

    return transform
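

# Illustrative use (assumes an img_aug config with deit_aug=False, img_size=224
# and img_scale=(0.08, 1.0)):
#   transform = build_img_transform(is_train=True, config=cfg.img_aug)
#   out = transform(pil_image)  # normalized FloatTensor of shape (3, 224, 224)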


def build_text_transform(is_train, config, with_dc=True):
    local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0
    if config.multi_label and is_train:
        # only download the NLTK models on local rank 0
        if local_rank == 0:
            nltk.download('popular')
        transform = WordAugTokenizeWrapper(
            Tokenize(SimpleTokenizer(), max_seq_len=config.max_seq_len),
            max_word=config.multi_label,
            word_type=config.word_type)
    else:
        transform = Tokenize(SimpleTokenizer(), max_seq_len=config.max_seq_len)

    if with_dc:
        transform = transforms.Compose([transform, ToDataContainer()])

    return transform
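

# Illustrative use (assumes a text_aug config with multi_label=3,
# max_seq_len=77 and word_type='noun'):
#   text_transform = build_text_transform(is_train=True, config=cfg.text_aug)
#   text_transform('a dog runs on the beach')
#   # -> DataContainer wrapping a LongTensor of shape (4, 77): the raw caption
#   #    plus multi_label prompt texts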


class Tokenize:
    """Wrap a CLIP-style BPE tokenizer and pad/truncate to a fixed length."""

    def __init__(self, tokenizer, max_seq_len=77, truncate=True):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.truncate = truncate

    def __call__(self, texts):
        expanded_dim = False
        if isinstance(texts, str):
            texts = [texts]
            expanded_dim = True

        sot_token = self.tokenizer.encoder['<|startoftext|>']
        eot_token = self.tokenizer.encoder['<|endoftext|>']
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), self.max_seq_len, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > self.max_seq_len:
                if self.truncate:
                    # truncate but keep the end-of-text marker as the final token
                    tokens = tokens[:self.max_seq_len]
                    tokens[-1] = eot_token
                else:
                    raise RuntimeError(f'Input {texts[i]} is too long for context length {self.max_seq_len}')
            result[i, :len(tokens)] = torch.tensor(tokens)

        if expanded_dim:
            return result[0]

        return result
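

# Illustrative behaviour of Tokenize (shapes follow from the code above):
#   tokenize = Tokenize(SimpleTokenizer(), max_seq_len=77)
#   tokenize('a photo of a cat')     # -> LongTensor of shape (77,)
#   tokenize(['a cat', 'two dogs'])  # -> LongTensor of shape (2, 77)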


class WordAugTokenizeWrapper:
    """Augment a caption with prompt templates filled with its nouns or noun
    phrases before tokenizing.

    Args:
        tokenize: base tokenize callable (e.g. a Tokenize instance)
        max_word: maximum number of words/phrases sampled per caption
        template_set: which set of prompt templates to use
        word_type: extract single words ('noun') or chunks ('noun_phrase')
    """

    def __init__(self, tokenize, max_word=3, template_set='full', word_type='noun'):
        self.tokenize = tokenize
        self.max_word = max_word
        from .imagenet_template import (full_imagenet_templates, identity_template, simple_imagenet_template,
                                        sub_imagenet_template)
        assert template_set in ['full', 'subset', 'simple', 'identity']
        if template_set == 'full':
            templates = full_imagenet_templates
        elif template_set == 'subset':
            templates = sub_imagenet_template
        elif template_set == 'simple':
            templates = simple_imagenet_template
        elif template_set == 'identity':
            templates = identity_template
        else:
            raise ValueError
        self.templates = templates
        assert word_type in ['noun', 'noun_phrase']
        self.word_type = word_type

    def get_tag(self, tokenized, tags):
        if not isinstance(tags, (list, tuple)):
            tags = [tags]
        ret = []
        for (word, pos) in nltk.pos_tag(tokenized):
            for tag in tags:
                if pos == tag:
                    ret.append(word)
        return ret

    def get_noun_phrase(self, tokenized):
        # Taken from Su Nam Kim Paper...
        grammar = r"""
            NBAR:
                {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns

            NP:
                {<NBAR>}
                {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
        """
        chunker = nltk.RegexpParser(grammar)

        chunked = chunker.parse(nltk.pos_tag(tokenized))
        continuous_chunk = []
        current_chunk = []

        for subtree in chunked:
            if isinstance(subtree, nltk.Tree):
                current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
            elif current_chunk:
                named_entity = ' '.join(current_chunk)
                if named_entity not in continuous_chunk:
                    continuous_chunk.append(named_entity)
                current_chunk = []

        # flush the final chunk; without this, a noun phrase at the very end
        # of the sentence would be silently dropped
        if current_chunk:
            named_entity = ' '.join(current_chunk)
            if named_entity not in continuous_chunk:
                continuous_chunk.append(named_entity)

        return continuous_chunk
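
    # Illustrative behaviour (exact chunks depend on the NLTK POS tagger):
    #   wrapper.get_noun_phrase(nltk.word_tokenize('a photo of a small dog'))
    #   # -> ['photo', 'small dog']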

    def __call__(self, text):
        assert isinstance(text, str)
        tokenized = nltk.word_tokenize(text)
        nouns = []

        if len(tokenized) > 0:
            if self.word_type == 'noun':
                # note that, besides noun tags, verb tags are accepted here too
                nouns = self.get_tag(tokenized, ['NN', 'NNS', 'NNP', 'VBG', 'VB', 'VBD', 'VBN', 'VBP', 'VBZ'])
            elif self.word_type == 'noun_phrase':
                nouns = self.get_noun_phrase(tokenized)
            else:
                raise ValueError('word_type must be noun or noun_phrase')

        prompt_texts = []
        if len(nouns) > 0:
            select_nouns = np.random.choice(nouns, min(self.max_word, len(nouns)), replace=False)
            prompt_texts = [np.random.choice(self.templates).format(noun) for noun in select_nouns]
        if len(prompt_texts) < self.max_word:
            # pad with the raw caption so the output always has max_word + 1 rows
            prompt_texts += [text] * (self.max_word - len(prompt_texts))

        texts = [text] + prompt_texts
        return self.tokenize(texts)
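

# Illustrative end-to-end behaviour (templates and words are sampled randomly):
#   aug = WordAugTokenizeWrapper(Tokenize(SimpleTokenizer(), max_seq_len=77), max_word=3)
#   aug('a dog runs on the beach')
#   # -> LongTensor of shape (4, 77): the raw caption plus 3 prompt texts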