- from typing import List
- from torch.utils.data import Dataset
- import os.path as osp
- import logging
- import torch
- from utils.iotools import read_image
- from utils.simple_tokenizer import SimpleTokenizer
- from prettytable import PrettyTable
- import random
- import regex as re
- import copy
class BaseDataset(object):
    """
    Base class of text to image reid dataset.

    Subclasses are expected to populate the per-split containers
    (``train_id_container``, ``train_annos``, ``train``, and the test/val
    equivalents) before calling :meth:`show_dataset_info`.
    """
    logger = logging.getLogger("IRRA.dataset")

    def show_dataset_info(self):
        """Log a table of (ids, images, captions) counts per subset."""
        # subset name -> (num ids, num images, num captions)
        split_stats = {
            'train': (len(self.train_id_container),
                      len(self.train_annos),
                      len(self.train)),
            'test': (len(self.test_id_container),
                     len(self.test_annos),
                     len(self.test['captions'])),
            'val': (len(self.val_id_container),
                    len(self.val_annos),
                    len(self.val['captions'])),
        }

        self.logger.info(f"{self.__class__.__name__} Dataset statistics:")
        table = PrettyTable(['subset', 'ids', 'images', 'captions'])
        for subset, (num_ids, num_imgs, num_caps) in split_stats.items():
            table.add_row([subset, num_ids, num_imgs, num_caps])
        self.logger.info('\n' + str(table))
def tokenize(caption: str, tokenizer, text_length=77, truncate=True) -> torch.LongTensor:
    """Encode *caption* into a fixed-length tensor of token ids.

    The caption is wrapped in start/end-of-text markers and right-padded
    with zeros to *text_length*.  An over-long caption is either clipped
    (keeping the end-of-text marker in the last slot) when *truncate* is
    true, or rejected with a RuntimeError otherwise.

    :param caption: raw caption string
    :param tokenizer: object exposing ``encoder`` (token -> id) and ``encode``
    :param text_length: fixed output length
    :param truncate: clip instead of raising on over-long input
    :return: LongTensor of shape (text_length,)
    """
    sot = tokenizer.encoder["<|startoftext|>"]
    eot = tokenizer.encoder["<|endoftext|>"]
    token_ids = [sot, *tokenizer.encode(caption), eot]

    if len(token_ids) > text_length:
        if not truncate:
            raise RuntimeError(
                f"Input {caption} is too long for context length {text_length}"
            )
        # Clip and force the final slot to remain the end-of-text marker.
        token_ids = token_ids[:text_length]
        token_ids[-1] = eot

    padded = torch.zeros(text_length, dtype=torch.long)
    padded[:len(token_ids)] = torch.tensor(token_ids)
    return padded
class ImageTextDataset(Dataset):
    """Training dataset yielding an image together with its tokenized caption.

    Each item of ``dataset`` is a ``(pid, image_id, img_path, caption)``
    tuple; captions are tokenized lazily in ``__getitem__``.
    """

    def __init__(self,
                 dataset,
                 transform=None,
                 text_length: int = 77,
                 truncate: bool = True):
        self.dataset = dataset
        self.transform = transform
        self.text_length = text_length
        self.truncate = truncate
        self.tokenizer = SimpleTokenizer()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        pid, image_id, img_path, caption = self.dataset[index]

        img = read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)

        token_ids = tokenize(caption,
                             tokenizer=self.tokenizer,
                             text_length=self.text_length,
                             truncate=self.truncate)

        return {
            'pids': pid,
            'image_ids': image_id,
            'images': img,
            'caption_ids': token_ids,
        }
class ImageDataset(Dataset):
    """Image-only dataset (gallery side of evaluation).

    ``image_pids`` and ``img_paths`` are parallel sequences.
    """

    def __init__(self, image_pids, img_paths, transform=None):
        self.image_pids = image_pids
        self.img_paths = img_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_pids)

    def __getitem__(self, index):
        pid = self.image_pids[index]
        img = read_image(self.img_paths[index])
        if self.transform is not None:
            img = self.transform(img)
        return pid, img
class TextDataset(Dataset):
    """Caption-only dataset (query side of evaluation).

    Each item is ``(pid, caption_ids)`` where ``caption_ids`` is the
    fixed-length token id tensor produced by ``tokenize``.
    """

    def __init__(self,
                 caption_pids,
                 captions,
                 text_length: int = 77,
                 truncate: bool = True,
                 bpe_path: str = "/mnt/vos-s9gjtkm2/reid/groupvit/GroupViT/datasets/bpe_simple_vocab_16e6.txt.gz"):
        """
        :param caption_pids: person id for each caption (parallel to captions)
        :param captions: list of raw caption strings
        :param text_length: fixed token-sequence length after padding
        :param truncate: clip over-long captions instead of raising
        :param bpe_path: path to the BPE vocabulary file.  The default is the
            previously hard-coded machine-specific absolute path, kept for
            backward compatibility.  NOTE(review): callers should pass a
            configurable path; also note the sibling dataset classes build
            ``SimpleTokenizer()`` with its default vocab — confirm whether
            this split really needs a different one.
        """
        self.caption_pids = caption_pids
        self.captions = captions
        self.text_length = text_length
        self.truncate = truncate
        self.tokenizer = SimpleTokenizer(bpe_path=bpe_path)

    def __len__(self):
        return len(self.caption_pids)

    def __getitem__(self, index):
        pid, caption = self.caption_pids[index], self.captions[index]

        caption = tokenize(caption, tokenizer=self.tokenizer,
                           text_length=self.text_length, truncate=self.truncate)

        return pid, caption
class ImageTextMLMDataset(Dataset):
    """Training dataset that additionally yields masked-language-model inputs.

    Each item carries the untouched caption token ids plus a randomly
    masked copy (``mlm_ids``) and per-position prediction targets
    (``mlm_labels``; 0 means "not masked", to be ignored by the MLM loss).
    """

    def __init__(self,
                 dataset,
                 transform=None,
                 text_length: int = 77,
                 truncate: bool = True):
        # dataset: sequence of (pid, image_id, img_path, caption) tuples
        self.dataset = dataset
        self.transform = transform
        self.text_length = text_length
        self.truncate = truncate
        self.tokenizer = SimpleTokenizer()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        pid, image_id, img_path, caption = self.dataset[index]
        img = read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)

        caption_tokens = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate)

        # Masking operates on a numpy copy so the original token tensor
        # returned as 'caption_ids' stays unmodified.
        mlm_tokens, mlm_labels = self._build_random_masked_tokens_and_labels(caption_tokens.cpu().numpy())

        ret = {
            'pids': pid,
            'image_ids': image_id,
            'images': img,
            'caption_ids': caption_tokens,
            'mlm_ids': mlm_tokens,
            'mlm_labels': mlm_labels
        }

        return ret

    def _build_random_masked_tokens_and_labels(self, tokens):
        """
        Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
        :param tokens: list of int, tokenized sentence.
        :return: (list of int, list of int), masked tokens and related labels for MLM prediction
        """
        mask = self.tokenizer.encoder["<|mask|>"]
        # Candidate replacement ids; excludes 0 (pad) and the last three
        # special-token ids of the vocabulary.
        token_range = list(range(1, len(self.tokenizer.encoder)-3)) # 1 ~ 49405

        labels = []
        for i, token in enumerate(tokens):
            # Only ordinary tokens are candidates: 0 is padding, ids >= 49405
            # are special tokens (start/end/mask) and must stay intact.
            if 0 < token < 49405:
                prob = random.random()
                # mask token with 15% probability
                if prob < 0.15:
                    # Rescale prob to [0, 1) to pick the replacement strategy.
                    prob /= 0.15

                    # 80% randomly change token to mask token
                    if prob < 0.8:
                        tokens[i] = mask
                    # 10% randomly change token to random token
                    elif prob < 0.9:
                        tokens[i] = random.choice(token_range)
                    # -> rest 10% randomly keep current token

                    # append current token to output (we will predict these later)
                    labels.append(token)
                else:
                    # no masking token (will be ignored by loss function later)
                    labels.append(0)
            else:
                labels.append(0)

        if all(l == 0 for l in labels):
            # at least mask 1: force-mask position 1 (first real token after
            # the start-of-text marker) so the MLM loss always has a target.
            labels[1] = tokens[1]
            tokens[1] = mask

        return torch.tensor(tokens), torch.tensor(labels)