# -------------------------------------------------------------------------
# Written by Jilan Xu
# -------------------------------------------------------------------------
import ast
import io
import os
import random

import numpy as np
import pandas as pd
import torch
from nltk.stem import WordNetLemmatizer
from PIL import Image
from torch.utils.data import Dataset

from .tokenizer import SimpleTokenizer
from .imagenet_template import full_imagenet_templates

lemmatizer = WordNetLemmatizer()

### the 100 most frequently occurring entities ###
TOP_CLASSES_1 = [
    'people', 'man', 'men', 'woman', 'women', 'girl', 'boy', 'lady', 'kid', 'child', 'children', 'baby', 'student', 'bride', 'groom', 'couple', 'prince', 'princess',
    'car', 'bus', 'truck', 'motorcycle', 'train', 'bicycle', 'boat', 'aeroplane', 'airplane', 'motorbike', 'bike',
    'cup', 'bottle', 'bowl', 'knife', 'spoon', 'glass', 'fork',
    'chair', 'table', 'bench', 'clock', 'laptop', 'light', 'vase', 'plant', 'remote', 'microwave', 'toaster', 'oven', 'mouse', 'keyboard', 'sofa', 'monitor', 'desk', 'tv', 'TV', 'couch', 'flower', 'refrigerator',
    'house', 'building', 'hotel',
    'handbag', 'umbrella', 'book', 'backpack', 'phone', 'shirt', 'tie', 'suitcase', 'T-shirt', 'bag', 'box',
    'sink', 'bed', 'toilet',
    'cat', 'dog', 'horse', 'bird', 'cow', 'sheep', 'elephant', 'bear', 'zebra', 'giraffe',
    'ball', 'racket', 'skateboard', 'skis', 'snowboard', 'surfboard', 'kite',
    'pizza', 'cake', 'apple', 'banana', 'sandwich', 'orange', 'carrot', 'donut',
]

### several entities are near-synonyms; map each to a single canonical name ###
syn_dict = {
    'people': 'people', 'man': 'people', 'men': 'people', 'woman': 'people', 'women': 'people', 'girl': 'people', 'boy': 'people', 'lady': 'people', 'kid': 'people', 'child': 'people', 'children': 'people', 'baby': 'people', 'student': 'people', 'bride': 'people', 'groom': 'people', 'couple': 'people', 'prince': 'people', 'princess': 'people',
    'airplane': 'aeroplane', 'motorbike': 'motorcycle', 'bike': 'bicycle',
    'TV': 'tv', 'desk': 'table', 'couch': 'sofa',
    'building': 'house', 'hotel': 'house',
    'T-shirt': 'shirt', 'T-Shirt': 'shirt', 'handbag': 'bag',
}

### unique entities after synonym merging ###
TOP_UNIQUE_CLASSES = [
    'people', 'car', 'bus', 'truck', 'motorcycle',
    'train', 'bicycle', 'boat', 'aeroplane', 'cup',
    'bottle', 'bowl', 'knife', 'spoon', 'glass',
    'fork', 'chair', 'table', 'bench', 'clock',
    'laptop', 'light', 'vase', 'plant', 'remote',
    'microwave', 'toaster', 'oven', 'mouse', 'keyboard',
    'sofa', 'monitor', 'tv', 'flower', 'refrigerator',
    'house', 'bag', 'umbrella', 'book', 'backpack',
    'phone', 'shirt', 'tie', 'suitcase', 'box',
    'sink', 'bed', 'toilet', 'cat', 'dog',
    'horse', 'bird', 'cow', 'sheep', 'elephant',
    'bear', 'zebra', 'giraffe', 'ball', 'racket',
    'skateboard', 'skis', 'snowboard', 'surfboard', 'kite',
    'pizza', 'cake', 'apple', 'banana', 'sandwich',
    'orange', 'carrot', 'donut',
]

TOP_UNIQUE_CLASSES_IDX = {x: i for i, x in enumerate(TOP_UNIQUE_CLASSES)}
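
# A few illustrative examples of the canonicalisation used by judge_noun below
# (lemmatise -> merge synonyms -> check membership); these assume the NLTK
# WordNet data is available:
#   'dogs'     -> 'dog'                     (in TOP_UNIQUE_CLASSES)
#   'airplane' -> 'airplane' -> 'aeroplane' (via syn_dict)
#   'men'      -> 'man'      -> 'people'    (via syn_dict)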

class ClipDataset(Dataset):
    """Image-text dataset for CLIP-style training with optional entity
    question/answer construction and cross-image sampling."""

    def __init__(self, root_dir, meta_file, img_transform=None, text_transform=None,
                 read_from='dir', label_texts_ensemble='none', split='train',
                 cross_image=False, use_entity=True, mask_type='class', use_distilbert=True):
        self.root_dir = root_dir
        self.meta_file = meta_file
        self.metas = pd.read_csv(self.meta_file)
        print(f'Total {len(self.metas)} samples')

        self.read_from = read_from
        if self.read_from == 'petrel':
            from petrel_client.client import Client
            self.client = Client()

        self.img_transform = img_transform
        self.text_transform = text_transform
        self.label_texts_ensemble = label_texts_ensemble
        self.split = split
        self.cross_image = cross_image
        self.use_entity = use_entity
        self.tokenizer = SimpleTokenizer()
        self.mask_type = mask_type
        self.use_distilbert = use_distilbert

    def __len__(self):
        return len(self.metas)

    def load_image(self, filename):
        filename = os.path.join(self.root_dir, filename)
        if self.read_from == 'dir':
            ### read directly from the local file system ###
            img = Image.open(filename).convert('RGB')
            return img
        elif self.read_from == 'petrel':
            ### fetch raw bytes from the petrel object store and decode in memory ###
            value = self.client.get(filename)
            img_bytes = np.frombuffer(value, dtype=np.uint8)
            with Image.open(io.BytesIO(img_bytes)) as img:
                img = img.convert('RGB')
            return img
        else:
            raise NotImplementedError

    def _load_meta(self, idx):
        return self.metas.iloc[idx]

    def sample_cross_image(self, curr_meta):
        ### sample a paired image that shares an entity with the current sample ###
        pair_index = curr_meta['pairindex']
        pair_entity = curr_meta['pairentity']
        pair_index_list = ast.literal_eval(pair_index)
        pair_entity_list = ast.literal_eval(pair_entity)

        sample_index = np.random.randint(0, len(pair_index_list))
        index = pair_index_list[sample_index]
        entity = pair_entity_list[sample_index]

        pair_meta = self._load_meta(index)
        img = self.load_image(pair_meta['image_id'])
        caption = pair_meta['caption']
        return img, caption, entity
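
    # Note on the meta CSV layout, as inferred from the code (not a guaranteed
    # schema): every row needs 'image_id' and 'caption'; with cross_image=True
    # it also needs 'pairindex' and 'pairentity', which hold string-encoded
    # Python lists parsed by ast.literal_eval, e.g.
    #   pairindex:  "[3, 17, 42]"
    #   pairentity: "['dog', 'sofa', 'dog']"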

    def __getitem__(self, idx):
        curr_meta = self._load_meta(idx)
        filename = curr_meta['image_id']
        raw_caption = caption = curr_meta['caption']
        label = int(curr_meta['label']) if 'label' in curr_meta else -1
        ret_info = {}

        try:
            ### skip non-English captions ###
            assert not self.is_contains_chinese(caption)
            img = self.load_image(filename)
            if self.img_transform is not None:
                image = self.img_transform(img)

            if self.text_transform is not None:
                if self.split == 'train':
                    ### for the clip TextTransformer, captions are tokenised here ###
                    ### for bert/distilbert, the text transform only selects nouns; captions are tokenised later ###
                    caption, nouns, locs, prompt_texts = self.text_transform(caption)

                    if self.use_entity:
                        ### a feasible option here is to pre-process questions and answers to speed up data loading ###
                        if self.use_distilbert:
                            ### bert/distilbert-like: questions/answers are tokenised later ###
                            raw_question, question, raw_answer, answer = self.build_question_and_answer_for_distilbert(raw_caption, nouns)
                        else:
                            ### clip TextTransformer-like: questions/answers are tokenised here ###
                            raw_question, question, raw_answer, answer = self.build_question_and_answer(raw_caption, nouns)

                        ret_info['question'] = question
                        ret_info['answer'] = answer
                        ret_info['raw_question'] = raw_question
                        ret_info['raw_answer'] = raw_answer

                    if self.cross_image:
                        crossimg, crosscaption, crossentity = self.sample_cross_image(curr_meta)
                        crossimg = self.img_transform(crossimg)
                        crossentity = 'A photo of ' + crossentity
                        ret_info['cross_image'] = crossimg
                        ret_info['cross_entity'] = crossentity
                        ret_info['cross_caption'] = crosscaption
                else:
                    caption = self.text_transform(caption)

            ret_info['image'] = image
            ret_info['caption'] = caption
            ret_info['raw_caption'] = raw_caption
            ret_info['target'] = label
            # ret_info['filename'] = filename
            return ret_info

        except Exception:
            ### corrupt or filtered samples fall back to a random resample ###
            return self.__getitem__(np.random.randint(0, len(self.metas)))

    def judge_noun(self, n):
        ### return (1, canonical_entity) if the word is a known entity, else (0, word) ###
        n = n.replace('.', '')
        ### lemmatise, then map synonyms to their canonical entity ###
        ans = lemmatizer.lemmatize(n.lower())
        if ans in syn_dict:
            ans = syn_dict[ans]
        if ans in TOP_UNIQUE_CLASSES:
            return 1, ans
        return 0, n
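
    # Example outcomes (assuming the NLTK WordNet corpus is installed):
    #   judge_noun('dogs')     -> (1, 'dog')
    #   judge_noun('airplane') -> (1, 'aeroplane')
    #   judge_noun('running')  -> (0, 'running')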

    def build_question_and_answer(self, caption, nouns):
        words = caption.split(' ')
        question = ''
        ans_list = []
        token_mapper = {}
        word_mapper = {}

        assert self.mask_type == 'class'
        for word in words:
            word = word.strip(' \n')
            if word.endswith("'s"):
                word = word[:-2]  ### drop the possessive suffix ###
            word_flag, newword = self.judge_noun(word)
            if word_flag == 1:
                question = question + newword + ' '
                ans_list.append(newword)
                token_id = self.tokenizer.encode(newword)[0]
                token_mapper[token_id] = TOP_UNIQUE_CLASSES_IDX[newword]
                word_mapper[token_id] = 332  ### 332 is the token id of 'M' ###
            else:
                question = question + word + ' '

        question = question.replace("'", '').strip()
        raw_question = question

        ### tokenise the question, then overwrite every entity token with the 'M' token id ###
        question, _, _, _ = self.text_transform(raw_question)
        question = torch.tensor([word_mapper[int(word)] if int(word) in word_mapper else word for word in question])
        # raw_answer = 'A photo of ' + ' and '.join(list(set(ans_list)))  ## fixed-prompt variant with unique words
        raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(ans_list)))
        answer, _, _, _ = self.text_transform(raw_answer)

        return raw_question, question, raw_answer, answer
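
    # Illustrative walk-through, assuming 'dog' and 'sofa' are recognised
    # entities: the caption "a dog on the sofa" gives
    #   raw_question: "a dog on the sofa"  (the 'dog'/'sofa' token ids become 332)
    #   raw_answer:   e.g. "a photo of a dog and sofa"  (the prefix depends on
    #                 the randomly sampled template; entity order may vary)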

    def build_question_and_answer_for_distilbert(self, caption, nouns):
        words = caption.split(' ')
        question = ''
        entity_list = []

        ### by default, mask all entities ###
        assert self.mask_type == 'class'
        for word in words:
            word = word.strip(' \n')
            if word.endswith("'s"):
                word = word[:-2]  ### drop the possessive suffix ###
            word_flag, newword = self.judge_noun(word)
            if word_flag == 1:
                question = question + '[MASK]' + ' '
                entity_list.append(newword)
            else:
                question = question + word + ' '

        question = question.replace("'", '').strip()
        raw_question = question
        ### build the answer; tokenisation happens later in the bert/distilbert pipeline ###
        # raw_answer = 'A photo of ' + ' and '.join(list(set(entity_list)))  ## fixed-prompt variant with unique words
        raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(entity_list)))
        return raw_question, None, raw_answer, None
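
    # e.g. "a dog chasing a cat" -> raw_question "a [MASK] chasing a [MASK]"
    # and raw_answer like "a photo of a dog and cat" (template and entity
    # order are sampled at run time).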

    def is_contains_chinese(self, strs):
        ### detect CJK characters so that non-English captions can be filtered out ###
        for _char in strs:
            if '\u4e00' <= _char <= '\u9fa5':
                return True
        return False
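

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the training pipeline): the
    # paths below are placeholders and must point at a real meta CSV with
    # 'image_id' and 'caption' columns plus the matching image directory.
    from torch.utils.data import DataLoader
    from torchvision import transforms

    img_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # split='val' with text_transform=None keeps captions as raw strings,
    # so the returned dicts collate without any tokeniser.
    dataset = ClipDataset(root_dir='data/images', meta_file='data/meta.csv',
                          img_transform=img_tf, text_transform=None, split='val')
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    print(batch['image'].shape, batch['raw_caption'][:2])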