4 Revize 2117e81e56 ... 90ddc4c5f7

Autor SHA1 Zpráva Datum
  Yijun Fu 90ddc4c5f7 代码能运行,但是效果不对 před 1 měsícem
  Yijun Fu 8805f2ac02 feat(evaluate): 更新 reid 评估方法 před 1 měsícem
  Yijun Fu e2a6f804a8 data(cuhkpedes): 更新数据集配置和加载逻辑 před 1 měsícem
  Yijun Fu 70475e7ec1 feat(dataset): 添加 CUHK-PEDES 数据集支持 před 1 měsícem

+ 10 - 6
configs/default.yml

@@ -6,6 +6,7 @@ data:
   # the differences become academic.
   shuffle_buffer: 10000
   seed: ${train.seed}
+  bpe_path: /home/linkslinks/文档/ai/GroupViT/datasets/bpe_simple_vocab_16e6.txt.gz
   dataset:
     meta:
       gcc3m:
@@ -34,14 +35,17 @@ data:
         prefix: imagenet-val-{000000..000049}.tar
         length: 50000
       cuhkpedes_train:
+        name: CUHK-PEDES
         type: img_txt_pair
         path: local_data/cuhkpedes_shards
-        prefix: cuhkpedes-train-{000000..000255}.tar
+        prefix: cuhkpedes-train-{000000..000004}.tar
         length: 34054
       cuhkpedes_val:
+        raw_path: /home/linkslinks/dataset/
+        name: CUHK-PEDES
         type: img_txt_pair
         path: local_data/cuhkpedes_shards
-        prefix: cuhkpedes-val-{000000..000023}.tar
+        prefix: cuhkpedes-val-{000000..000000}.tar
         length: 3078
     train:
       # - gcc3m
@@ -92,9 +96,9 @@ evaluate:
   eval_only: false
   eval_freq: 1
   task:
-    - cls
-    - seg
-    - retrieval
+    # - cls
+    # - seg
+    - reid
   cls:
     save_best: true
     template: subset
@@ -103,7 +107,7 @@ evaluate:
     cfg: segmentation/configs/_base_/datasets/pascal_voc12.py
     template: simple
     opts: []
-  retrieval:
+  reid:
     save_best: true
     template: simple
     opts: []

+ 210 - 0
datasets/bases.py

@@ -0,0 +1,210 @@
+from typing import List
+from torch.utils.data import Dataset
+import os.path as osp
+import logging
+import torch
+from utils.iotools import read_image
+from utils.simple_tokenizer import SimpleTokenizer
+from prettytable import PrettyTable
+import random
+import regex as re
+import copy
+
+
+class BaseDataset(object):
+    """
+    Base class of text to image reid dataset
+    """
+    logger = logging.getLogger("IRRA.dataset")
+
+    def show_dataset_info(self):
+        num_train_pids, num_train_imgs, num_train_captions = len(
+            self.train_id_container), len(self.train_annos), len(self.train)
+        num_test_pids, num_test_imgs, num_test_captions = len(
+            self.test_id_container), len(self.test_annos), len(
+                self.test['captions'])
+        num_val_pids, num_val_imgs, num_val_captions = len(
+            self.val_id_container), len(self.val_annos), len(
+                self.val['captions'])
+
+        # TODO use prettytable print comand line table
+
+        self.logger.info(f"{self.__class__.__name__} Dataset statistics:")
+        table = PrettyTable(['subset', 'ids', 'images', 'captions'])
+        table.add_row(
+            ['train', num_train_pids, num_train_imgs, num_train_captions])
+        table.add_row(
+            ['test', num_test_pids, num_test_imgs, num_test_captions])
+        table.add_row(['val', num_val_pids, num_val_imgs, num_val_captions])
+        self.logger.info('\n' + str(table))
+
+
+def tokenize(caption: str, tokenizer, text_length=77, truncate=True) -> torch.LongTensor:
+    sot_token = tokenizer.encoder["<|startoftext|>"]
+    eot_token = tokenizer.encoder["<|endoftext|>"]
+    tokens = [sot_token] + tokenizer.encode(caption) + [eot_token]
+
+    result = torch.zeros(text_length, dtype=torch.long)
+    if len(tokens) > text_length:
+        if truncate:
+            tokens = tokens[:text_length]
+            tokens[-1] = eot_token
+        else:
+            raise RuntimeError(
+                f"Input {caption} is too long for context length {text_length}"
+            )
+    result[:len(tokens)] = torch.tensor(tokens)
+    return result
+
+
+class ImageTextDataset(Dataset):
+    def __init__(self,
+                 dataset,
+                 transform=None,
+                 text_length: int = 77,
+                 truncate: bool = True):
+        self.dataset = dataset
+        self.transform = transform
+        self.text_length = text_length
+        self.truncate = truncate
+        self.tokenizer = SimpleTokenizer()
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        pid, image_id, img_path, caption = self.dataset[index]
+        img = read_image(img_path)
+        if self.transform is not None:
+            img = self.transform(img)
+
+        tokens = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate)
+
+        ret = {
+            'pids': pid,
+            'image_ids': image_id,
+            'images': img,
+            'caption_ids': tokens,
+        }
+
+        return ret
+
+
+class ImageDataset(Dataset):
+    def __init__(self, image_pids, img_paths, transform=None):
+        self.image_pids = image_pids
+        self.img_paths = img_paths
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.image_pids)
+
+    def __getitem__(self, index):
+        pid, img_path = self.image_pids[index], self.img_paths[index]
+        img = read_image(img_path)
+        if self.transform is not None:
+            img = self.transform(img)
+        return pid, img
+
+
+class TextDataset(Dataset):
+    def __init__(self,
+                 caption_pids,
+                 captions,
+                 text_length: int = 77,
+                 truncate: bool = True):
+        self.caption_pids = caption_pids
+        self.captions = captions
+        self.text_length = text_length
+        self.truncate = truncate
+        self.tokenizer = SimpleTokenizer(bpe_path="/home/linkslinks/文档/ai/GroupViT/datasets/bpe_simple_vocab_16e6.txt.gz")
+
+    def __len__(self):
+        return len(self.caption_pids)
+
+    def __getitem__(self, index):
+        pid, caption = self.caption_pids[index], self.captions[index]
+
+        caption = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate)
+
+        return pid, caption
+
+
+class ImageTextMLMDataset(Dataset):
+    def __init__(self,
+                 dataset,
+                 transform=None,
+                 text_length: int = 77,
+                 truncate: bool = True):
+        self.dataset = dataset
+        self.transform = transform
+        self.text_length = text_length
+        self.truncate = truncate
+
+        self.tokenizer = SimpleTokenizer()
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        pid, image_id, img_path, caption = self.dataset[index]
+        img = read_image(img_path)
+        if self.transform is not None:
+            img = self.transform(img)
+        
+        caption_tokens = tokenize(caption, tokenizer=self.tokenizer, text_length=self.text_length, truncate=self.truncate)
+
+        mlm_tokens, mlm_labels = self._build_random_masked_tokens_and_labels(caption_tokens.cpu().numpy())
+
+        ret = {
+            'pids': pid,
+            'image_ids': image_id,
+            'images': img,
+            'caption_ids': caption_tokens,
+            'mlm_ids': mlm_tokens,
+            'mlm_labels': mlm_labels
+        }
+
+        return ret
+
+    def _build_random_masked_tokens_and_labels(self, tokens):
+        """
+        Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
+        :param tokens: list of int, tokenized sentence.
+        :return: (list of int, list of int), masked tokens and related labels for MLM prediction
+        """
+        mask = self.tokenizer.encoder["<|mask|>"]
+        token_range = list(range(1, len(self.tokenizer.encoder)-3)) # 1 ~ 49405
+        
+        labels = []
+        for i, token in enumerate(tokens):
+            if 0 < token < 49405:
+                prob = random.random()
+                # mask token with 15% probability
+                if prob < 0.15:
+                    prob /= 0.15
+
+                    # 80% randomly change token to mask token
+                    if prob < 0.8:
+                        tokens[i] = mask
+
+                    # 10% randomly change token to random token
+                    elif prob < 0.9:
+                        tokens[i] = random.choice(token_range)
+
+                    # -> rest 10% randomly keep current token
+
+                    # append current token to output (we will predict these later)
+                    labels.append(token)
+                else:
+                    # no masking token (will be ignored by loss function later)
+                    labels.append(0)
+            else:
+                labels.append(0)
+        
+        if all(l == 0 for l in labels):
+            # at least mask 1
+            labels[1] = tokens[1]
+            tokens[1] = mask
+
+        return torch.tensor(tokens), torch.tensor(labels)

+ 92 - 0
datasets/build.py

@@ -0,0 +1,92 @@
+import logging
+import torch
+import torchvision.transforms as T
+from torch.utils.data import DataLoader
+
+# from datasets.sampler import RandomIdentitySampler
+# from datasets.sampler_ddp import RandomIdentitySampler_DDP
+from utils.comm import get_world_size
+from .cuhkpedes import CUHKPEDES
+from .bases import ImageDataset, TextDataset, ImageTextDataset, ImageTextMLMDataset
+
+# __factory = {'CUHK-PEDES': CUHKPEDES, 'ICFG-PEDES': ICFGPEDES, 'RSTPReid': RSTPReid}
+__factory = {'CUHK-PEDES': CUHKPEDES}
+
+def build_transforms(img_size=(384, 128), aug=False, is_train=True):
+    height, width = img_size
+
+    mean = [0.48145466, 0.4578275, 0.40821073]
+    std = [0.26862954, 0.26130258, 0.27577711]
+
+    if not is_train:
+        transform = T.Compose([
+            T.Resize((height, width)),
+            T.ToTensor(),
+            T.Normalize(mean=mean, std=std),
+        ])
+        return transform
+
+    # transform for training
+    if aug:
+        transform = T.Compose([
+            T.Resize((height, width)),
+            T.RandomHorizontalFlip(0.5),
+            T.Pad(10),
+            T.RandomCrop((height, width)),
+            T.ToTensor(),
+            T.Normalize(mean=mean, std=std),
+            T.RandomErasing(scale=(0.02, 0.4), value=mean),
+        ])
+    else:
+        transform = T.Compose([
+            T.Resize((height, width)),
+            T.RandomHorizontalFlip(0.5),
+            T.ToTensor(),
+            T.Normalize(mean=mean, std=std),
+        ])
+    return transform
+
+def collate(batch):
+    keys = set([key for b in batch for key in b.keys()])
+    # turn list of dicts data structure to dict of lists data structure
+    dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys}
+
+    batch_tensor_dict = {}
+    for k, v in dict_batch.items():
+        if isinstance(v[0], int):
+            batch_tensor_dict.update({k: torch.tensor(v)})
+        elif torch.is_tensor(v[0]):
+             batch_tensor_dict.update({k: torch.stack(v)})
+        else:
+            raise TypeError(f"Unexpect data type: {type(v[0])} in a batch.")
+
+    return batch_tensor_dict
+
+def build_dataloader(args, tranforms=None):
+    logger = logging.getLogger("IRRA.dataset")
+
+    num_workers = args.data.num_workers
+    dataset = __factory[args.data.dataset.meta.cuhkpedes_val.name](root=args.data.dataset.meta.cuhkpedes_val.raw_path)
+    num_classes = len(dataset.train_id_container)
+
+    val_transforms = build_transforms(img_size=(args.data.img_aug.img_size * 3, args.data.img_aug.img_size),
+                                    is_train=False)
+
+    # use test set as validate set
+    ds = dataset.val
+    val_img_set = ImageDataset(ds['image_pids'], ds['img_paths'],
+                                val_transforms)
+    val_txt_set = TextDataset(ds['caption_pids'],
+                                ds['captions'],
+                                text_length=args.data.text_aug.max_seq_len)
+
+    val_img_loader = DataLoader(val_img_set,
+                                batch_size=args.data.batch_size,
+                                shuffle=False,
+                                num_workers=num_workers)
+    val_txt_loader = DataLoader(val_txt_set,
+                                batch_size=args.data.batch_size,
+                                shuffle=False,
+                                num_workers=num_workers)
+
+    return val_img_loader, val_txt_loader, num_classes

+ 114 - 0
datasets/cuhkpedes.py

@@ -0,0 +1,114 @@
+import os.path as op
+from typing import List
+
+from utils.iotools import read_json
+from .bases import BaseDataset
+
+
+class CUHKPEDES(BaseDataset):
+    """
+    CUHK-PEDES
+
+    Reference:
+    Person Search With Natural Language Description (CVPR 2017)
+
+    URL: https://openaccess.thecvf.com/content_cvpr_2017/html/Li_Person_Search_With_CVPR_2017_paper.html
+
+    Dataset statistics:
+    ### identities: 13003
+    ### images: 40206,  (train)  (test)  (val)
+    ### captions: 
+    ### 9 images have more than 2 captions
+    ### 4 identity have only one image
+
+    annotation format: 
+    [{'split', str,
+      'captions', list,
+      'file_path', str,
+      'processed_tokens', list,
+      'id', int}...]
+    """
+    dataset_dir = 'CUHK-PEDES'
+
+    def __init__(self, root='', verbose=True):
+        super(CUHKPEDES, self).__init__()
+        self.dataset_dir = op.join(root, self.dataset_dir)
+        self.img_dir = op.join(self.dataset_dir, 'imgs/')
+
+        self.anno_path = op.join(self.dataset_dir, 'reid_raw.json')
+        self._check_before_run()
+
+        self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path)
+
+        self.train, self.train_id_container = self._process_anno(self.train_annos, training=True)
+        self.test, self.test_id_container = self._process_anno(self.test_annos)
+        self.val, self.val_id_container = self._process_anno(self.val_annos)
+
+        if verbose:
+            self.logger.info("=> CUHK-PEDES Images and Captions are loaded")
+            self.show_dataset_info()
+
+
+    def _split_anno(self, anno_path: str):
+        train_annos, test_annos, val_annos = [], [], []
+        annos = read_json(anno_path)
+        for anno in annos:
+            if anno['split'] == 'train':
+                train_annos.append(anno)
+            elif anno['split'] == 'test':
+                test_annos.append(anno)
+            else:
+                val_annos.append(anno)
+        return train_annos, test_annos, val_annos
+
+  
+    def _process_anno(self, annos: List[dict], training=False):
+        pid_container = set()
+        if training:
+            dataset = []
+            image_id = 0
+            for anno in annos:
+                pid = int(anno['id']) - 1 # make pid begin from 0
+                pid_container.add(pid)
+                img_path = op.join(self.img_dir, anno['file_path'])
+                captions = anno['captions'] # caption list
+                for caption in captions:
+                    dataset.append((pid, image_id, img_path, caption))
+                image_id += 1
+            for idx, pid in enumerate(pid_container):
+                # check pid begin from 0 and no break
+                assert idx == pid, f"idx: {idx} and pid: {pid} are not match"
+            return dataset, pid_container
+        else:
+            dataset = {}
+            img_paths = []
+            captions = []
+            image_pids = []
+            caption_pids = []
+            for anno in annos:
+                pid = int(anno['id'])
+                pid_container.add(pid)
+                img_path = op.join(self.img_dir, anno['file_path'])
+                img_paths.append(img_path)
+                image_pids.append(pid)
+                caption_list = anno['captions'] # caption list
+                for caption in caption_list:
+                    captions.append(caption)
+                    caption_pids.append(pid)
+            dataset = {
+                "image_pids": image_pids,
+                "img_paths": img_paths,
+                "caption_pids": caption_pids,
+                "captions": captions
+            }
+            return dataset, pid_container
+
+
+    def _check_before_run(self):
+        """Check if all files are available before going deeper"""
+        if not op.exists(self.dataset_dir):
+            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
+        if not op.exists(self.img_dir):
+            raise RuntimeError("'{}' is not available".format(self.img_dir))
+        if not op.exists(self.anno_path):
+            raise RuntimeError("'{}' is not available".format(self.anno_path))

+ 67 - 0
datasets/sampler.py

@@ -0,0 +1,67 @@
+from torch.utils.data.sampler import Sampler
+from collections import defaultdict
+import copy
+import random
+import numpy as np
+
+class RandomIdentitySampler(Sampler):
+    """
+    Randomly sample N identities, then for each identity,
+    randomly sample K instances, therefore batch size is N*K.
+    Args:
+    - data_source (list): list of (img_path, pid, camid).
+    - num_instances (int): number of instances per identity in a batch.
+    - batch_size (int): number of examples in a batch.
+    """
+
+    def __init__(self, data_source, batch_size, num_instances):
+        self.data_source = data_source
+        self.batch_size = batch_size
+        self.num_instances = num_instances
+        self.num_pids_per_batch = self.batch_size // self.num_instances
+        self.index_dic = defaultdict(list) #dict with list value
+        #{783: [0, 5, 116, 876, 1554, 2041],...,}
+        for index, (pid, _, _, _) in enumerate(self.data_source):
+            self.index_dic[pid].append(index)
+        self.pids = list(self.index_dic.keys())
+
+        # estimate number of examples in an epoch
+        self.length = 0
+        for pid in self.pids:
+            idxs = self.index_dic[pid]
+            num = len(idxs)
+            if num < self.num_instances:
+                num = self.num_instances
+            self.length += num - num % self.num_instances
+
+    def __iter__(self):
+        batch_idxs_dict = defaultdict(list)
+
+        for pid in self.pids:
+            idxs = copy.deepcopy(self.index_dic[pid])
+            if len(idxs) < self.num_instances:
+                idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
+            random.shuffle(idxs)
+            batch_idxs = []
+            for idx in idxs:
+                batch_idxs.append(idx)
+                if len(batch_idxs) == self.num_instances:
+                    batch_idxs_dict[pid].append(batch_idxs)
+                    batch_idxs = []
+
+        avai_pids = copy.deepcopy(self.pids)
+        final_idxs = []
+
+        while len(avai_pids) >= self.num_pids_per_batch:
+            selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
+            for pid in selected_pids:
+                batch_idxs = batch_idxs_dict[pid].pop(0)
+                final_idxs.extend(batch_idxs)
+                if len(batch_idxs_dict[pid]) == 0:
+                    avai_pids.remove(pid)
+
+        return iter(final_idxs)
+
+    def __len__(self):
+        return self.length
+

+ 197 - 0
datasets/sampler_ddp.py

@@ -0,0 +1,197 @@
+from torch.utils.data.sampler import Sampler
+from collections import defaultdict
+import copy
+import random
+import numpy as np
+import math
+import torch.distributed as dist
+_LOCAL_PROCESS_GROUP = None
+import torch
+import pickle
+
+def _get_global_gloo_group():
+    """
+    Return a process group based on gloo backend, containing all the ranks
+    The result is cached.
+    """
+    if dist.get_backend() == "nccl":
+        return dist.new_group(backend="gloo")
+    else:
+        return dist.group.WORLD
+
+def _serialize_to_tensor(data, group):
+    backend = dist.get_backend(group)
+    assert backend in ["gloo", "nccl"]
+    device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+    buffer = pickle.dumps(data)
+    if len(buffer) > 1024 ** 3:
+        print(
+            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+                dist.get_rank(), len(buffer) / (1024 ** 3), device
+            )
+        )
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to(device=device)
+    return tensor
+
+def _pad_to_largest_tensor(tensor, group):
+    """
+    Returns:
+        list[int]: size of the tensor, on each rank
+        Tensor: padded tensor that has the max size
+    """
+    world_size = dist.get_world_size(group=group)
+    assert (
+            world_size >= 1
+    ), "comm.gather/all_gather must be called from ranks within the given group!"
+    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+    size_list = [
+        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+    ]
+    dist.all_gather(size_list, local_size, group=group)
+    size_list = [int(size.item()) for size in size_list]
+
+    max_size = max(size_list)
+
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    if local_size != max_size:
+        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+        tensor = torch.cat((tensor, padding), dim=0)
+    return size_list, tensor
+
+def all_gather(data, group=None):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors).
+    Args:
+        data: any picklable object
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    if dist.get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group) == 1:
+        return [data]
+
+    tensor = _serialize_to_tensor(data, group)
+
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    tensor_list = [
+        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+    ]
+    dist.all_gather(tensor_list, tensor, group=group)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+def shared_random_seed():
+    """
+    Returns:
+        int: a random number that is the same across all workers.
+            If workers need a shared RNG, they can use this shared seed to
+            create one.
+    All workers must call this function, otherwise it will deadlock.
+    """
+    ints = np.random.randint(2 ** 31)
+    all_ints = all_gather(ints)
+    return all_ints[0]
+
+class RandomIdentitySampler_DDP(Sampler):
+    """
+    Randomly sample N identities, then for each identity,
+    randomly sample K instances, therefore batch size is N*K.
+    Args:
+    - data_source (list): list of (img_path, pid, camid).
+    - num_instances (int): number of instances per identity in a batch.
+    - batch_size (int): number of examples in a batch.
+    """
+
+    def __init__(self, data_source, batch_size, num_instances):
+        self.data_source = data_source
+        self.batch_size = batch_size
+        self.world_size = dist.get_world_size()
+        self.num_instances = num_instances
+        self.mini_batch_size = self.batch_size // self.world_size
+        self.num_pids_per_batch = self.mini_batch_size // self.num_instances
+        self.index_dic = defaultdict(list)
+
+        for index, (pid, _, _, _) in enumerate(self.data_source):
+            self.index_dic[pid].append(index)
+        self.pids = list(self.index_dic.keys())
+
+        # estimate number of examples in an epoch
+        self.length = 0
+        for pid in self.pids:
+            idxs = self.index_dic[pid]
+            num = len(idxs)
+            if num < self.num_instances:
+                num = self.num_instances
+            self.length += num - num % self.num_instances
+
+        self.rank = dist.get_rank()
+        #self.world_size = dist.get_world_size()
+        self.length //= self.world_size
+
+    def __iter__(self):
+        seed = shared_random_seed()
+        np.random.seed(seed)
+        self._seed = int(seed)
+        final_idxs = self.sample_list()
+        length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size))
+        #final_idxs = final_idxs[self.rank * length:(self.rank + 1) * length]
+        final_idxs = self.__fetch_current_node_idxs(final_idxs, length)
+        self.length = len(final_idxs)
+        return iter(final_idxs)
+
+
+    def __fetch_current_node_idxs(self, final_idxs, length):
+        total_num = len(final_idxs)
+        block_num = (length // self.mini_batch_size)
+        index_target = []
+        for i in range(0, block_num * self.world_size, self.world_size):
+            index = range(self.mini_batch_size * self.rank + self.mini_batch_size * i, min(self.mini_batch_size * self.rank + self.mini_batch_size * (i+1), total_num))
+            index_target.extend(index)
+        index_target_npy = np.array(index_target)
+        final_idxs = list(np.array(final_idxs)[index_target_npy])
+        return final_idxs
+
+
+    def sample_list(self):
+        #np.random.seed(self._seed)
+        avai_pids = copy.deepcopy(self.pids)
+        batch_idxs_dict = {}
+
+        batch_indices = []
+        while len(avai_pids) >= self.num_pids_per_batch:
+            selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist()
+            for pid in selected_pids:
+                if pid not in batch_idxs_dict:
+                    idxs = copy.deepcopy(self.index_dic[pid])
+                    if len(idxs) < self.num_instances:
+                        idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
+                    np.random.shuffle(idxs)
+                    batch_idxs_dict[pid] = idxs
+
+                avai_idxs = batch_idxs_dict[pid]
+                for _ in range(self.num_instances):
+                    batch_indices.append(avai_idxs.pop(0))
+
+                if len(avai_idxs) < self.num_instances: avai_pids.remove(pid)
+
+        return batch_indices
+
+    def __len__(self):
+        return self.length
+

+ 71 - 3
main_group_vit.py

@@ -45,10 +45,15 @@ from mmseg.apis import multi_gpu_test
 from models import build_model
 from omegaconf import OmegaConf, read_write
 from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
+from datasets.build import build_dataloader
 from timm.utils import AverageMeter, accuracy
 from utils import (auto_resume_helper, build_dataset_class_tokens, build_optimizer, build_scheduler, data2cuda,
                    get_config, get_grad_norm, get_logger, load_checkpoint, parse_losses, reduce_tensor, save_checkpoint)
 
+from tools.cfg2arg import cfg2arg
+from utils.metrics import Evaluator
+
+
 try:
     # noinspection PyUnresolvedReferences
     from apex import amp
@@ -103,6 +108,14 @@ def train(cfg):
         data_loader_train, data_loader_val = build_loader(cfg.data)
     data_loader_seg = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))
 
+    print("\n\n\n")
+    print(cfg)
+    print("\n\n\n")
+
+    # get image-text pair datasets dataloader
+    # train_loader, val_img_loader, val_txt_loader, num_classes = build_dataloader(cfg)
+    val_img_loader, val_txt_loader, num_classes = build_dataloader(cfg)
+
     logger = get_logger()
 
     logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
@@ -148,8 +161,8 @@ def train(cfg):
         else:
             logger.info(f'no checkpoint found in {cfg.output}, ignoring auto resume')
 
-    max_accuracy = max_miou = 0.0
-    max_metrics = {'max_accuracy': max_accuracy, 'max_miou': max_miou}
+    max_accuracy = max_miou = max_rank1 = 0.0
+    max_metrics = {'max_accuracy': max_accuracy, 'max_miou': max_miou, 'max_rank1': max_rank1}
 
     if cfg.checkpoint.resume:
         max_metrics = load_checkpoint(cfg, model_without_ddp, optimizer, lr_scheduler)
@@ -160,6 +173,11 @@ def train(cfg):
         if 'seg' in cfg.evaluate.task:
             miou = validate_seg(cfg, data_loader_seg, model)
             logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
+        if 'reid' in cfg.evaluate.task:
+            # mrank1 = validate_reid(cfg, data_loader_reid, model)
+            mrank1 = validate_reid(cfg, val_img_loader, val_txt_loader, model)
+            # logger.info(f'Rank1 of the network on the {len(data_loader_reid)} test images: {mrank1:.2f}%')
+            logger.info(f'Rank1 of the network on the {len(val_img_loader)} test images: {mrank1:.2f}%')
         if cfg.evaluate.eval_only:
             return
 
@@ -170,7 +188,8 @@ def train(cfg):
         if dist.get_rank() == 0 and (epoch % cfg.checkpoint.save_freq == 0 or epoch == (cfg.train.epochs - 1)):
             save_checkpoint(cfg, epoch, model_without_ddp, {
                 'max_accuracy': max_accuracy,
-                'max_miou': max_miou
+                'max_miou': max_miou,
+                'max_rank1': max_rank1
             }, optimizer, lr_scheduler)
         dist.barrier()
         loss_train = loss_train_dict['total_loss']
@@ -198,6 +217,16 @@ def train(cfg):
                 dist.barrier()
                 max_miou = max_metrics['max_miou']
                 logger.info(f'Max mIoU: {max_miou:.2f}%')
+            if 'reid' in cfg.evaluate.task:
+                mrank1 = validate_reid(cfg, val_img_loader, val_txt_loader, model)
+                logger.info(f'mRank1 of the network on the {len(val_img_loader)} test images: {mrank1:.2f}%')
+                max_metrics['max_rank1'] = max(max_metrics['max_rank1'], mrank1)
+                if cfg.evaluate.reid.save_best and dist.get_rank() == 0 and mrank1 > max_rank1:
+                    save_checkpoint(
+                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_rank1')
+                dist.barrier()
+                max_rank1 = max_metrics['max_rank1']
+                logger.info(f'Max mRank1: {max_rank1:.2f}%')
 
         if wandb is not None:
             log_stat = {f'epoch/train_{k}': v for k, v in loss_train_dict.items()}
@@ -206,6 +235,7 @@ def train(cfg):
                 'epoch/val_acc5': acc5,
                 'epoch/val_loss': loss,
                 'epoch/val_miou': miou,
+                'epoch/val_rank1': mrank1,
                 'epoch/epoch': epoch,
                 'epoch/n_parameters': n_parameters
             })
@@ -413,6 +443,44 @@ def validate_seg(config, data_loader, model):
     return miou_result
 
 
+@torch.no_grad()
+def validate_reid(cfg, img_loader, txt_loader, model):
+    logger = get_logger()
+    dist.barrier()
+    # model.eval()
+    evaluator = Evaluator(img_loader, txt_loader)
+
+    if hasattr(model, 'module'):
+        model_without_ddp = model.module
+    else:
+        model_without_ddp = model
+
+    # reid_model = build_reid_inference(model_without_ddp, img_loader, txt_loader, cfg.evaluate.reid)
+
+    # mmddp_model = MMDistributedDataParallel(
+    #     reid_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
+    rank1 = evaluator.eval(model_without_ddp.eval())
+    # results = multi_gpu_test(
+    #     model=mmddp_model,
+    #     data_loader=img_loader,
+    #     tmpdir=None,
+    #     gpu_collect=True,
+    #     efficient_test=False,
+    #     pre_eval=True,
+    #     format_only=False)
+
+    # if dist.get_rank() == 0:
+    #     metric = [img_loader.dataset.evaluate(results, metric='Rank-1')]
+    # else:
+    #     metric = [None]
+    # dist.broadcast_object_list(metric)
+    # rank1_result = metric[0]['Rank-1'] * 100
+
+    torch.cuda.empty_cache()
+
+    return rank1
+
+
 def main():
     args = parse_args()
     cfg = get_config(args)

+ 5 - 0
run.sh

@@ -0,0 +1,5 @@
+./tools/dist_launch.sh \
+main_group_vit.py \
+configs/group_vit_gcc_redcap_cuhkpedes_30e.yml \
+checkpoint/group_vit_gcc_redcap_30e-3dd09a76.pth \
+1

+ 62 - 0
tools/cfg2arg.py

@@ -0,0 +1,62 @@
+import argparse
+
+
+def cfg2arg(cfg):
+    # 定义argparse对象
+    parser = argparse.ArgumentParser()
+
+    # 添加参数
+    parser.add_argument('--local_rank', type=int, default=cfg['local_rank'])
+    parser.add_argument('--name', type=str, default=cfg['model_name'])
+    parser.add_argument('--output_dir', type=str, default=cfg['output'])
+    parser.add_argument('--log_period', type=int, default=cfg['print_freq'])
+    parser.add_argument('--eval_period', type=int, default=cfg['evaluate']['eval_freq'])
+    parser.add_argument('--val_dataset', type=str, default=cfg['data']['dataset']['val'][0])
+    parser.add_argument('--resume', type=bool, default=cfg['checkpoint']['auto_resume'])
+    parser.add_argument('--resume_ckpt_file', type=str, default=cfg['checkpoint']['resume'])
+    parser.add_argument('--pretrain_choice', type=str, default='ViT-B/16')  # 这里假设预训练选择是固定的
+    parser.add_argument('--temperature', type=float, default=cfg['model']['contrast_temperature'])
+    parser.add_argument('--img_aug', type=bool, default=cfg['data']['img_aug']['deit_aug'])
+    parser.add_argument('--cmt_depth', type=int, default=4)  # 这里假设cmt_depth是固定的
+    parser.add_argument('--masked_token_rate', type=float, default=0.8)  # 这里假设masked_token_rate是固定的
+    parser.add_argument('--masked_token_unchanged_rate', type=float, default=0.1)  # 这里假设masked_token_unchanged_rate是固定的
+    parser.add_argument('--lr_factor', type=float, default=5.0)  # 这里假设lr_factor是固定的
+    parser.add_argument('--MLM', type=bool, default=True)  # 这里假设MLM是固定的
+    parser.add_argument('--loss_names', type=str, default='sdm+mlm+id')  # 这里假设loss_names是固定的
+    parser.add_argument('--mlm_loss_weight', type=float, default=1.0)  # 这里假设mlm_loss_weight是固定的
+    parser.add_argument('--id_loss_weight', type=float, default=1.0)  # 这里假设id_loss_weight是固定的
+    parser.add_argument('--img_size', type=tuple, default=(cfg['data']['img_aug']['img_size'], cfg['data']['img_aug']['img_size']))
+    parser.add_argument('--stride_size', type=int, default=16)  # 这里假设stride_size是固定的
+    parser.add_argument('--text_length', type=int, default=cfg['data']['text_aug']['max_seq_len'])
+    parser.add_argument('--vocab_size', type=int, default=cfg['model']['text_encoder']['vocab_size'])
+    parser.add_argument('--optimizer', type=str, default=cfg['train']['optimizer']['name'])
+    parser.add_argument('--lr', type=float, default=cfg['train']['base_lr'])
+    parser.add_argument('--bias_lr_factor', type=float, default=2.0)  # 这里假设bias_lr_factor是固定的
+    parser.add_argument('--momentum', type=float, default=0.9)  # 这里假设momentum是固定的
+    parser.add_argument('--weight_decay', type=float, default=cfg['train']['weight_decay'])
+    parser.add_argument('--weight_decay_bias', type=float, default=0.0)  # 这里假设weight_decay_bias是固定的
+    parser.add_argument('--alpha', type=float, default=0.9)  # 这里假设alpha是固定的
+    parser.add_argument('--beta', type=float, default=0.999)  # 这里假设beta是固定的
+    parser.add_argument('--num_epoch', type=int, default=cfg['train']['epochs'])
+    parser.add_argument('--milestones', type=tuple, default=(20, 50))  # 这里假设milestones是固定的
+    parser.add_argument('--gamma', type=float, default=0.1)  # 这里假设gamma是固定的
+    parser.add_argument('--warmup_factor', type=float, default=0.1)  # 这里假设warmup_factor是固定的
+    parser.add_argument('--warmup_epochs', type=int, default=cfg['train']['warmup_epochs'])
+    parser.add_argument('--warmup_method', type=str, default='linear')  # 这里假设warmup_method是固定的
+    parser.add_argument('--lrscheduler', type=str, default=cfg['train']['lr_scheduler']['name'])
+    parser.add_argument('--target_lr', type=float, default=0)  # 这里假设target_lr是固定的
+    parser.add_argument('--power', type=float, default=0.9)  # 这里假设power是固定的
+    parser.add_argument('--dataset_name', type=str, default='CUHK-PEDES')  # 这里假设dataset_name是固定的
+    parser.add_argument('--sampler', type=str, default='random')  # 这里假设sampler是固定的
+    parser.add_argument('--num_instance', type=int, default=4)  # 这里假设num_instance是固定的
+    parser.add_argument('--root_dir', type=str, default='/home/linkslinks/dataset')  # 这里假设root_dir是固定的
+    parser.add_argument('--batch_size', type=int, default=cfg['data']['batch_size'])
+    parser.add_argument('--test_batch_size', type=int, default=512)  # 这里假设test_batch_size是固定的
+    parser.add_argument('--num_workers', type=int, default=cfg['data']['num_workers'])
+    parser.add_argument('--training', type=bool, default=True)  # 这里假设training是固定的
+    parser.add_argument('--distributed', type=bool, default=False)  # 这里假设distributed是固定的
+
+    # 解析参数
+    args = parser.parse_args()
+
+    return args

+ 116 - 0
utils/comm.py

@@ -0,0 +1,116 @@
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import pickle
+
+import torch
+import torch.distributed as dist
+
+
+def get_world_size():
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def synchronize():
+    """
+    Helper function to synchronize (barrier) among all processes when
+    using distributed training
+    """
+    if not dist.is_available():
+        return
+    if not dist.is_initialized():
+        return
+    world_size = dist.get_world_size()
+    if world_size == 1:
+        return
+    dist.barrier()
+
+
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialized to a Tensor
+    buffer = pickle.dumps(data)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to("cuda")
+
+    # obtain Tensor size of each rank
+    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
+    size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
+    if local_size != max_size:
+        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that process with rank
+    0 has the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.reduce(values, dst=0)
+        if dist.get_rank() == 0 and average:
+            # only main process gets accumulated, so only divide by
+            # world_size in this case
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict

+ 76 - 0
utils/iotools.py

@@ -0,0 +1,76 @@
+# encoding: utf-8
+"""
+@author:  sherlock
+@contact: sherlockliao01@gmail.com
+"""
+from PIL import Image, ImageFile
+import errno
+import json
+import pickle as pkl
+import os
+import os.path as osp
+import yaml
+from easydict import EasyDict as edict
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+def read_image(img_path):
+    """Keep reading image until succeed.
+    This can avoid IOError incurred by heavy IO process."""
+    got_img = False
+    if not osp.exists(img_path):
+        raise IOError("{} does not exist".format(img_path))
+    while not got_img:
+        try:
+            img = Image.open(img_path).convert('RGB')
+            got_img = True
+        except IOError:
+            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
+            pass
+    return img
+
+
+def mkdir_if_missing(directory):
+    if not osp.exists(directory):
+        try:
+            os.makedirs(directory)
+        except OSError as e:
+            if e.errno != errno.EEXIST:
+                raise
+
+
+def check_isfile(path):
+    isfile = osp.isfile(path)
+    if not isfile:
+        print("=> Warning: no file found at '{}' (ignored)".format(path))
+    return isfile
+
+
+def read_json(fpath):
+    with open(fpath, 'r') as f:
+        obj = json.load(f)
+    return obj
+
+
+def write_json(obj, fpath):
+    mkdir_if_missing(osp.dirname(fpath))
+    with open(fpath, 'w') as f:
+        json.dump(obj, f, indent=4, separators=(',', ': '))
+
+
+def get_text_embedding(path, length):
+    with open(path, 'rb') as f:
+        word_frequency = pkl.load(f)
+
+
+def save_train_configs(path, args):
+    if not os.path.exists(path):
+        os.makedirs(path)
+    with open(f'{path}/configs.yaml', 'w') as f:
+        yaml.dump(vars(args), f, default_flow_style=False)
+
+def load_train_configs(path):
+    with open(path, 'r') as f:
+        args = yaml.load(f, Loader=yaml.FullLoader)
+    return edict(args)

+ 103 - 0
utils/metrics.py

@@ -0,0 +1,103 @@
+from prettytable import PrettyTable
+import torch
+import numpy as np
+import os
+import torch.nn.functional as F
+import logging
+
+
+def rank(similarity, q_pids, g_pids, max_rank=10, get_mAP=True):
+    if get_mAP:
+        indices = torch.argsort(similarity, dim=1, descending=True)
+    else:
+        # acclerate sort with topk
+        _, indices = torch.topk(
+            similarity, k=max_rank, dim=1, largest=True, sorted=True
+        )  # q * topk
+    pred_labels = g_pids[indices.cpu()]  # q * k
+    matches = pred_labels.eq(q_pids.view(-1, 1))  # q * k
+
+    all_cmc = matches[:, :max_rank].cumsum(1) # cumulative sum
+    all_cmc[all_cmc > 1] = 1
+    all_cmc = all_cmc.float().mean(0) * 100
+    # all_cmc = all_cmc[topk - 1]
+
+    if not get_mAP:
+        return all_cmc, indices
+
+    num_rel = matches.sum(1)  # q
+    tmp_cmc = matches.cumsum(1)  # q * k
+
+    inp = [tmp_cmc[i][match_row.nonzero()[-1]] / (match_row.nonzero()[-1] + 1.) for i, match_row in enumerate(matches)]
+    mINP = torch.cat(inp).mean() * 100
+
+    tmp_cmc = [tmp_cmc[:, i] / (i + 1.0) for i in range(tmp_cmc.shape[1])]
+    tmp_cmc = torch.stack(tmp_cmc, 1) * matches
+    AP = tmp_cmc.sum(1) / num_rel  # q
+    mAP = AP.mean() * 100
+
+    return all_cmc, mAP, mINP, indices
+
+
+class Evaluator():
+    def __init__(self, img_loader, txt_loader):
+        self.img_loader = img_loader # gallery
+        self.txt_loader = txt_loader # query
+        self.logger = logging.getLogger("IRRA.eval")
+
+    def _compute_embedding(self, model):
+        model = model.eval()
+        device = next(model.parameters()).device
+
+        qids, gids, qfeats, gfeats = [], [], [], []
+        # text
+        for pid, caption in self.txt_loader:
+            print('pid', pid.shape[0])
+            print('caption: ', caption.shape[0])
+            caption = caption.to(device)
+            with torch.no_grad():
+                text_feat = model.encode_text(caption)
+            qids.append(pid.view(-1)) # flatten 
+            qfeats.append(text_feat)
+        qids = torch.cat(qids, 0)
+        qfeats = torch.cat(qfeats, 0)
+
+        # image
+        for pid, img in self.img_loader:
+            img = img.to(device)
+            with torch.no_grad():
+                img_feat = model.encode_image(img)
+            gids.append(pid.view(-1)) # flatten 
+            gfeats.append(img_feat)
+        gids = torch.cat(gids, 0)
+        gfeats = torch.cat(gfeats, 0)
+
+        return qfeats, gfeats, qids, gids
+    
+    def eval(self, model, i2t_metric=False):
+
+        qfeats, gfeats, qids, gids = self._compute_embedding(model)
+
+        qfeats = F.normalize(qfeats, p=2, dim=1) # text features
+        gfeats = F.normalize(gfeats, p=2, dim=1) # image features
+
+        similarity = qfeats @ gfeats.t()
+
+        t2i_cmc, t2i_mAP, t2i_mINP, _ = rank(similarity=similarity, q_pids=qids, g_pids=gids, max_rank=10, get_mAP=True)
+        t2i_cmc, t2i_mAP, t2i_mINP = t2i_cmc.numpy(), t2i_mAP.numpy(), t2i_mINP.numpy()
+        table = PrettyTable(["task", "R1", "R5", "R10", "mAP", "mINP"])
+        table.add_row(['t2i', t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_mAP, t2i_mINP])
+
+        if i2t_metric:
+            i2t_cmc, i2t_mAP, i2t_mINP, _ = rank(similarity=similarity.t(), q_pids=gids, g_pids=qids, max_rank=10, get_mAP=True)
+            i2t_cmc, i2t_mAP, i2t_mINP = i2t_cmc.numpy(), i2t_mAP.numpy(), i2t_mINP.numpy()
+            table.add_row(['i2t', i2t_cmc[0], i2t_cmc[4], i2t_cmc[9], i2t_mAP, i2t_mINP])
+        # table.float_format = '.4'
+        table.custom_format["R1"] = lambda f, v: f"{v:.3f}"
+        table.custom_format["R5"] = lambda f, v: f"{v:.3f}"
+        table.custom_format["R10"] = lambda f, v: f"{v:.3f}"
+        table.custom_format["mAP"] = lambda f, v: f"{v:.3f}"
+        table.custom_format["mINP"] = lambda f, v: f"{v:.3f}"
+        self.logger.info('\n' + str(table))
+        
+        return t2i_cmc[0]

+ 135 - 0
utils/simple_tokenizer.py

@@ -0,0 +1,135 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "../data/bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a signficant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+        merges = merges[1:49152-256-2+1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v+'</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        
+        vocab.pop(-1) # remove last one in vocab(jekyll) to keep vocab_size unchanged
+        vocab.extend(['<|mask|>', '<|startoftext|>', '<|endoftext|>']) # vocab_size 49408
+        # vocab.extend(['<|startoftext|>', '<|endoftext|>']) # vocab_size 49408
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|mask|>': '<|mask|>', '<|endoftext|>': '<|endoftext|>'}
+        self.pat = re.compile(r"""<\|startoftext\|>|<\|mask\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token+'</w>'
+
+        while True:
+            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+        return text