jazzcharles 2 years ago
parent
commit
3267f33fdf
77 files changed with 9807 additions and 2 deletions
  1. 4 0
      .gitignore
  2. 97 0
      LICENSE
  3. 109 2
      README.md
  4. 115 0
      configs/default.yml
  5. 114 0
      configs/ovsegmentor/debug.yml
  6. 114 0
      configs/ovsegmentor/ovsegmentor_pretrain_vit_bert_stage1.yml
  7. 123 0
      configs/ovsegmentor/ovsegmentor_pretrain_vit_bert_stage2.yml
  8. 121 0
      configs/test_ade20k.yml
  9. 123 0
      configs/test_coco.yml
  10. 121 0
      configs/test_coco_stuff.yml
  11. 118 0
      configs/test_voc12.yml
  12. 123 0
      configs/test_voc_context.yml
  13. 288 0
      convert_dataset/convert_coco_object.py
  14. 119 0
      convert_dataset/convert_yfcc14m.py
  15. 114 0
      convert_dataset/create_subset.py
  16. 71 0
      convert_dataset/process_redcaps.py
  17. 20 0
      datasets/__init__.py
  18. 119 0
      datasets/base_dataset.py
  19. BIN
      datasets/bpe_simple_vocab_16e6.txt.gz
  20. 368 0
      datasets/builder.py
  21. 457 0
      datasets/clip_dataset.py
  22. 36 0
      datasets/formatting.py
  23. 46 0
      datasets/image_reader.py
  24. 267 0
      datasets/imagenet_template.py
  25. 247 0
      datasets/sampler.py
  26. 170 0
      datasets/tokenizer.py
  27. BIN
      figs/model.png
  28. 628 0
      main_pretrain.py
  29. 273 0
      main_seg.py
  30. 23 0
      models/__init__.py
  31. 24 0
      models/builder.py
  32. 537 0
      models/clipmodel.py
  33. 1014 0
      models/group_vit.py
  34. 175 0
      models/losses.py
  35. 80 0
      models/misc.py
  36. 811 0
      models/multi_label_contrastive.py
  37. 319 0
      models/transformer.py
  38. 60 0
      models/utils.py
  39. 314 0
      models/vision_transformer.py
  40. 24 0
      requirements.txt
  41. BIN
      segmentation/.DS_Store
  42. BIN
      segmentation/configs/.DS_Store
  43. BIN
      segmentation/configs/_base_/.DS_Store
  44. 15 0
      segmentation/configs/_base_/custom_import.py
  45. 47 0
      segmentation/configs/_base_/datasets/ade20k.py
  46. 45 0
      segmentation/configs/_base_/datasets/coco.py
  47. 46 0
      segmentation/configs/_base_/datasets/coco_stuff.py
  48. 45 0
      segmentation/configs/_base_/datasets/pascal_context.py
  49. 58 0
      segmentation/configs/_base_/datasets/pascal_voc12 copy.py
  50. 46 0
      segmentation/configs/_base_/datasets/pascal_voc12.py
  51. 23 0
      segmentation/datasets/__init__.py
  52. 48 0
      segmentation/datasets/ade20k.py
  53. 51 0
      segmentation/datasets/coco_object.py
  54. 58 0
      segmentation/datasets/coco_stuff.py
  55. 26 0
      segmentation/datasets/pascal_context.py
  56. 25 0
      segmentation/datasets/pascal_voc.py
  57. 22 0
      segmentation/evaluation/__init__.py
  58. 120 0
      segmentation/evaluation/builder.py
  59. 209 0
      segmentation/evaluation/group_palette.txt
  60. 378 0
      segmentation/evaluation/group_vit_seg.py
  61. 32 0
      setup.cfg
  62. 5 0
      tools/debug.sh
  63. 3 0
      tools/run.sh
  64. 4 0
      tools/run_slurm.sh
  65. 4 0
      tools/run_slurm_stage2.sh
  66. 3 0
      tools/run_stage2.sh
  67. 5 0
      tools/test_ade20k.sh
  68. 5 0
      tools/test_coco.sh
  69. 5 0
      tools/test_context.sh
  70. 5 0
      tools/test_voc12.sh
  71. 28 0
      utils/__init__.py
  72. 186 0
      utils/checkpoint.py
  73. 79 0
      utils/config.py
  74. 61 0
      utils/logger.py
  75. 36 0
      utils/lr_scheduler.py
  76. 126 0
      utils/misc.py
  77. 72 0
      utils/optimizer.py

+ 4 - 0
.gitignore

@@ -0,0 +1,4 @@
+imagenet_info/
+*.pyc
+*.pth
+*.pt

+ 97 - 0
LICENSE

@@ -0,0 +1,97 @@
+Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All rights reserved.
+
+
+NVIDIA Source Code License for GroupViT: Semantic Segmentation Emerges from Text Supervision
+
+
+=======================================================================
+
+1. Definitions
+
+“Licensor” means any person or entity that distributes its Work.
+
+“Software” means the original work of authorship made available under 
+this License.
+
+“Work” means the Software and any additions to or derivative works of 
+the Software that are made available under this License.
+
+The terms “reproduce,” “reproduction,” “derivative works,” and 
+“distribution” have the meaning as provided under U.S. copyright law; 
+provided, however, that for the purposes of this License, derivative 
+works shall not include works that remain separable from, or merely 
+link (or bind by name) to the interfaces of, the Work.
+
+Works, including the Software, are “made available” under this License 
+by including in or with the Work either (a) a copyright notice 
+referencing the applicability of this License to the Work, or (b) a 
+copy of this License.
+
+2. License Grants
+
+    2.1 Copyright Grant. Subject to the terms and conditions of this
+    License, each Licensor grants to you a perpetual, worldwide,
+    non-exclusive, royalty-free, copyright license to reproduce,
+    prepare derivative works of, publicly display, publicly perform,
+    sublicense and distribute its Work and any resulting derivative
+    works in any form.
+
+3. Limitations
+
+    3.1 Redistribution. You may reproduce or distribute the Work only
+    if (a) you do so under this License, (b) you include a complete
+    copy of this License with your distribution, and (c) you retain
+    without modification any copyright, patent, trademark, or
+    attribution notices that are present in the Work.
+
+    3.2 Derivative Works. You may specify that additional or different
+    terms apply to the use, reproduction, and distribution of your
+    derivative works of the Work ("Your Terms") only if (a) Your Terms
+    provide that the use limitation in Section 3.3 applies to your
+    derivative works, and (b) you identify the specific derivative
+    works that are subject to Your Terms. Notwithstanding Your Terms,
+    this License (including the redistribution requirements in Section
+    3.1) will continue to apply to the Work itself.
+
+    3.3 Use Limitation. The Work and any derivative works thereof only
+    may be used or intended for use non-commercially. Notwithstanding
+    the foregoing, NVIDIA and its affiliates may use the Work and any
+    derivative works commercially. As used herein, "non-commercially"
+    means for research or evaluation purposes only.
+
+    3.4 Patent Claims. If you bring or threaten to bring a patent claim
+    against any Licensor (including any claim, cross-claim or
+    counterclaim in a lawsuit) to enforce any patents that you allege
+    are infringed by any Work, then your rights under this License from
+    such Licensor (including the grant in Section 2.1) will terminate
+    immediately.
+
+    3.5 Trademarks. This License does not grant any rights to use any
+    Licensor’s or its affiliates’ names, logos, or trademarks, except
+    as necessary to reproduce the notices described in this License.
+
+    3.6 Termination. If you violate any term of this License, then your
+    rights under this License (including the grant in Section 2.1) will
+    terminate immediately.
+
+4. Disclaimer of Warranty.
+
+THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+THIS LICENSE.
+
+5. Limitation of Liability.
+
+EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+THE POSSIBILITY OF SUCH DAMAGES.
+
+=======================================================================

+ 109 - 2
README.md

@@ -1,2 +1,109 @@
-# OVSegmentor
-Comming soon
+# Learning Open-Vocabulary Semantic Segmentation Models From Natural Language Supervision
+
+This repository is the official implementation of [Learning Open-Vocabulary Semantic Segmentation Models From Natural Language Supervision](https://arxiv.org/abs/2301.09121) at CVPR 2023. Our transformer-based model, termed OVSegmentor, is pre-trained on image-text pairs without using any mask annotations. After training, it can segment objects of arbitrary categories via zero-shot transfer.
+
+
+<div align="center">
+<img src="figs/model.png" width="100%">
+</div>
+
+## Requirements
+* Python 3.9
+* [torch=1.11.0+cu113](https://pytorch.org/)
+* [torchvision=0.14.1](https://pytorch.org/)
+* [apex=0.1](https://github.com/NVIDIA/apex)
+* [mmcv-full=1.3.14](https://github.com/open-mmlab/mmcv)
+* [mmsegmentation=0.18.0](https://github.com/open-mmlab/mmsegmentation)
+* [clip=1.0](https://github.com/openai/CLIP)
+
+We recommend installing apex with CUDA and C++ extensions.
+
+To install the other requirements:
+
+```setup
+pip install -r requirements.txt
+```
+
+## Prepare datasets
+For training, we construct CC4M by filtering CC12M with a total of 100 frequently appearing entities. Researchers are encouraged to prepare the CC12M dataset from the [source](https://github.com/google-research-datasets/conceptual-12m) or by using [img2dataset](https://github.com/rom1504/img2dataset). Note that some URL links may no longer be available. The file structure should follow:
+
+```shell
+CC12M
+├── 000002a0c848e78c7b9d53584e2d36ab0ac14785.jpg
+├── 000002ca5e5eab763d95fa8ac0df7a11f24519e5.jpg
+├── 00000440ca9fe337152041e26c37f619ec4c55b2.jpg
+...
+```
+We provide the meta-file for CC4M [here](https://drive.google.com/file/d/1ENpsWndAkWc0UZJvdJJDicPxpzrPugve/view?usp=share_link) for data loading. One may also try different [image-caption datasets](https://github.com/rom1504/img2dataset) (e.g. YFCC, RedCaps) by providing the images and a corresponding meta-file. The meta-file is a JSON-lines file in which each line contains a filename and its caption:
+```shell
+{"filename": "000002ca5e5eab763d95fa8ac0df7a11f24519e5.jpg", "caption": "A man's hand holds an orange pencil on white"}
+{"filename": "000009b46e38a28790f481f36366c781e03e4bbd.jpg", "caption": "Cooking is chemistry, except you CAN lick the spoon!"}
+...
+```
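+As a reference, below is a minimal sketch of how such a meta-file could be produced from downloaded images and their captions (the helper name and inputs are illustrative, not part of this repository):
+```python
+import json
+import os
+
+def build_meta_file(captions, image_dir, out_path):
+    """Write a JSON-lines meta-file with one {"filename", "caption"} record per line."""
+    with open(out_path, 'w') as f:
+        for filename, caption in captions.items():
+            # skip entries whose image failed to download
+            if not os.path.exists(os.path.join(image_dir, filename)):
+                continue
+            f.write(json.dumps({'filename': filename, 'caption': caption}) + '\n')
+
+captions = {'example.jpg': 'A man holds an orange pencil on white'}  # toy example
+build_meta_file(captions, image_dir='CC12M', out_path='cc4m.json')
+```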
+For evaluation, please follow the official instructions to prepare [PASCAL VOC](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#pascal-voc), [PASCAL Context](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#pascal-context), [COCO](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#coco-stuff-164k) (converted to semantic segmentation format following [GroupViT](https://github.com/NVlabs/GroupViT)), and [ADE20K](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#ade20k). Remember to change the image directories in segmentation/configs/_base_/datasets/*.py.
+
+To enable zero-shot classification evaluation, please prepare the validation set of [ImageNet](https://www.image-net.org/) with its corresponding meta-file. 
+
+## Other preparations
+1. The visual encoder is initialised with [DINO](https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth). Edit the checkpoint path in the config file.
+2. The pre-trained BERT model and nltk_data should be downloaded automatically.
+
+## Training
+To train the model(s) in the paper, we split the training process into a two-stage pipeline. The first stage is 30 epochs of training with the image-caption contrastive loss and the masked entity completion loss; the second stage is a further 10 epochs that add the cross-image mask consistency loss.
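+Conceptually, denoting the three losses as $\mathcal{L}_{contrast}$, $\mathcal{L}_{entity}$ and $\mathcal{L}_{mask}$, stage 1 optimizes $\mathcal{L}_{contrast} + \mathcal{L}_{entity}$ and stage 2 optimizes $\mathcal{L}_{contrast} + \mathcal{L}_{entity} + \mathcal{L}_{mask}$ (loss weights omitted; see the paper for the exact formulation).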
+
+For the first-stage training on a single node with 8 A100 (80G) GPUs, we recommend using the slurm script:
+
+```train
+cd OVSegmentor
+./tools/run_slurm.sh
+```
+Or simply use torch.distributed.launch as:
+
+```train
+./tools/run.sh
+```
+
+After that, please specify the checkpoint path from the first-stage training in the config file used for the second-stage training (e.g. configs/ovsegmentor/ovsegmentor_pretrain_vit_bert_stage2.yml). During cross-image sampling, we sample another image that shares the same entity with the current image. This is achieved by (1) identifying the visual entity for the image, and (2) sampling over the valid candidates. We offer the pre-processed [class_label.json](https://drive.google.com/file/d/15s0Pwn11bkB-RqGmpzf7z6lYPOd1sIZF/view?usp=share_link) and the sample list [here](https://drive.google.com/file/d/10sA94ZawsgL0E01im9-5xZciWnsCZOQz/view?usp=share_link).
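+The sketch below illustrates this two-step sampling. It assumes class_label.json maps each image to a list of entity labels and sample_list.json maps each entity to its candidate images; the actual schema and the repository's data pipeline may differ in details.
+```python
+import json
+import random
+
+# Illustrative only: the repository performs this sampling inside its training data pipeline.
+class_label = json.load(open('class_label.json'))  # filename -> list of entity ids (assumed)
+sample_list = json.load(open('sample_list.json'))  # entity id -> list of filenames (assumed)
+
+def sample_cross_image(filename):
+    """Return another image that shares an entity with `filename`, for the mask consistency loss."""
+    entity = random.choice(class_label[filename])                    # (1) identify a visual entity
+    candidates = [f for f in sample_list[str(entity)] if f != filename]
+    return random.choice(candidates) if candidates else filename    # (2) sample a valid candidate
+```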
+
+We also provide our pre-trained first-stage checkpoint [here](https://drive.google.com/file/d/19Kpeh5iTgGSr5mzf4n0j5hqxGDgG-Wxi/view?usp=share_link).
+
+Then, perform the second stage training. 
+```train
+./tools/run_slurm_stage2.sh
+```
+We adjust a few hyperparameters in the second stage to stabilize the training process.
+
+## Evaluation
+
+To evaluate the model on PASCAL VOC, please specify the checkpoint path in tools/test_voc12.sh, and run:
+
+```eval
+./tools/test_voc12.sh
+```
+For PASCAL Context, COCO Object, and ADE20K, please refer to tools/.
+
+The performance may vary by 3%~4% due to the randomness in cross-image sampling.
+
+## Model Zoo
+
+The pre-trained models can be downloaded from here:
+
+| Model name  | Visual enc | Text enc     | Group tokens | PASCAL VOC | PASCAL Context | COCO Object | ADE20K | Checkpoint |
+| ----------- | ---------- | ------------ | ------------ | ---------- | -------------- | ----------- | ------ | ---------- |
+| OVSegmentor | ViT-B      | BERT-Base    | 8            | 53.8       | 20.4           | 25.1        | 5.6    | [download](https://drive.google.com/file/d/10F3b3FNzPdDx8LuKdjc1BzbSLMrPLvnc/view?usp=share_link) |
+| OVSegmentor | ViT-S      | Roberta-Base | 8            | 44.5       | 18.3           | 19.0        | 4.3    | [download](https://drive.google.com/file/d/10F3b3FNzPdDx8LuKdjc1BzbSLMrPLvnc/view?usp=share_link) |
+| OVSegmentor | ViT-B      | BERT-Base    | 16           | Todo       | Todo           | Todo        | Todo   | Todo       |
+
+## Citation
+If this work is helpful for your research, please consider citing us.
+```
+@article{xu2023learning,
+  title={Learning Open-vocabulary Semantic Segmentation Models From Natural Language Supervision},
+  author={Xu, Jilan and Hou, Junlin and Zhang, Yuejie and Feng, Rui and Wang, Yi and Qiao, Yu and Xie, Weidi},
+  journal={arXiv preprint arXiv:2301.09121},
+  year={2023}
+}
+```
+
+## Acknowledgements
+This project is built upon [GroupViT](https://github.com/NVlabs/GroupViT). Thanks to the contributors of the great codebase.

+ 115 - 0
configs/default.yml

@@ -0,0 +1,115 @@
+data:
+  batch_size: 256
+  pin_memory: true
+  num_workers: 10
+  # Thomas said it should be at least about 5-10x your batch size; beyond that,
+  # the differences become academic.
+  shuffle_buffer: 10000
+  seed: ${train.seed}
+  dataset:
+    meta:
+      gcc3m:
+        type: img_txt_pair
+        path: local_data/gcc3m_shards
+        prefix: gcc-train-{000000..00436}.tar
+        length: 2891445
+      gcc12m:
+        type: img_txt_pair
+        path: local_data/gcc12m_shards
+        prefix: gcc-conceptual-12m-{000000..001943}.tar
+        length: 11156203
+      yfcc14m:
+        type: img_txt_pair
+        path: local_data/yfcc14m_shards
+        prefix: yfcc14m-{000000..001888}.tar
+        length: 14615499
+      redcap12m:
+        type: img_txt_pair
+        path: local_data/redcap12m_shards
+        prefix: redcap12m-{000000..001211}.tar
+        length: 11866987
+      imagenet:
+        type: img_cls_pair
+        path: local_data/imagenet_shards
+        prefix: imagenet-val-{000000..000049}.tar
+        length: 50000
+    train:
+      - gcc3m
+      - gcc12m
+      - yfcc14m
+    val:
+      - imagenet
+
+  img_aug:
+    deit_aug: true
+    img_size: 224
+    img_scale: [0.08, 1.0]
+    interpolation: bilinear
+    color_jitter: 0.4
+    auto_augment: 'rand-m9-mstd0.5-inc1'
+    re_prob: 0.25
+    re_mode: 'pixel'
+    re_count: 1
+  text_aug:
+    max_seq_len: 77
+    multi_label: 0
+    word_type: 'noun'
+
+train:
+  start_epoch: 0
+  epochs: 30
+  warmup_epochs: 2
+  base_lr: 1.6e-3
+  weight_decay: 0.05
+  warmup_lr: 4e-6
+  min_lr: 4e-5
+  clip_grad: 5.0
+  accumulation_steps: 0
+  amp_opt_level: O1
+  seed: 0
+  use_entity: False
+
+  lr_scheduler:
+    name: cosine
+
+  optimizer:
+    name: adamw
+    eps: 1e-8
+    betas: [0.9, 0.999]
+
+evaluate:
+  eval_only: false
+  eval_freq: 1
+  task:
+    - cls
+    - seg
+  cls:
+    save_best: true
+    template: subset
+  seg:
+    save_best: true
+    cfg: segmentation/configs/_base_/datasets/pascal_voc12.py
+    template: simple
+    opts: []
+
+checkpoint:
+  auto_resume: true
+  resume: ''
+  stage1_checkpoint: '' ## add this for stage2 training
+  freq: 1
+  max_kept: -1
+  save_freq: 1
+
+model:
+  use_maskloss: false
+  use_entityloss: false
+
+
+model_name: '' # display name in the logger
+output: ???
+tag: default
+print_freq: 10
+seed: 0
+wandb: false
+local_rank: ???
+vis: []
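The experiment configs below (e.g. configs/ovsegmentor/*.yml) inherit from this default.yml through the `_base_` key and rely on `${...}` interpolation such as `${train.seed}`. A minimal sketch of how such configs can be resolved, assuming an OmegaConf-style loader (the repository's utils/config.py may differ):

```python
# Sketch only: assumes OmegaConf; the repo's own config utilities may handle
# `_base_` resolution and command-line overrides differently.
from omegaconf import OmegaConf

base = OmegaConf.load('configs/default.yml')
child = OmegaConf.load('configs/ovsegmentor/ovsegmentor_pretrain_vit_bert_stage1.yml')
cfg = OmegaConf.merge(base, child)   # keys in the child config override the defaults

print(cfg.train.base_lr)   # 0.00064, overriding default.yml's 1.6e-3
print(cfg.data.seed)       # ${train.seed} resolves to 0 at access time
```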

+ 114 - 0
configs/ovsegmentor/debug.yml

@@ -0,0 +1,114 @@
+_base_: '../default.yml'
+model_name: 'debug' # display name in the logger
+output: /mnt/petrelfs/xujilan/exps/
+
+print_freq: 100
+data:
+  with_dc: False
+  train: 
+      root_dir: [
+                  's3://GCC/GCC12m/',
+                ]
+      meta_file: [
+                  '/mnt/petrelfs/xujilan/data/cc12m_100/cc4m.json',
+                ]
+      read_from: petrel
+      use_dali: True
+      batch_size: 256
+      input_size: 224
+      test_resize: 256
+
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed_epoch
+      transforms:
+          type: STANDARD
+      fseek: True
+      use_ranked: False
+      
+      use_entity: ${model.use_entityloss}
+      mask_type: class
+      use_distilbert: True
+  
+  val:
+      type: clip
+      read_from: petrel
+      use_dali: True
+      batch_size: 64
+      num_workers: 4
+      pin_memory: False
+      input_size: 224
+      test_resize: 256
+      
+      root_dir: '/mnt/cache/share/images/val/'
+      meta_file: 'imagenet_info/val.json'
+      # you can change it to imagenet_info relative path, file already in gitlab
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed
+      transforms:
+          type: ONECROP
+      evaluator:
+          type: imagenet
+          kwargs:
+              topk: [1, 5]
+      label_texts_ensemble: 'prompt1'
+          
+  img_aug:
+    deit_aug: true
+    img_size: 224
+    img_scale: [0.08, 1.0]
+    interpolation: 'bilinear'
+    # interpolation: 2
+    color_jitter: 0.4
+    auto_augment: 'rand-m9-mstd0.5-inc1'
+    re_prob: 0.25
+    re_mode: 'pixel'
+    re_count: 1
+  text_aug:
+    max_seq_len: 77
+    multi_label: 0 # we do not use multi-label contrastive 
+    word_type: 'noun'
+
+model:
+  type: MultiLabelContrastive
+  img_encoder:
+    type: GroupViT
+    embed_dim: 768
+    num_heads: [8, 8]
+    embed_factors: [1, 1]
+    depths: [6, 6]
+    num_group_tokens: [64, 0]
+    num_output_groups: [8]
+    drop_rate: 0.0
+    drop_path_rate: 0.1
+    patch_norm: false
+    imgnet_pretrained: 'dino'
+    fixed: false
+    imgnet_pretrained_checkpoint: '/mnt/petrelfs/xujilan/checkpoints/dino_vitbase16_pretrain.pth'
+
+  text_encoder:
+    type: Bert
+    context_length: 77
+    width: 768
+    layers: 6
+    vocab_size: 49408
+    pretrained: true
+    fixed: false
+  contrast_temperature: 0.07
+  proj_num_layers: 2
+  output_dim: 256
+  multi_label: ${data.text_aug.multi_label}
+  use_entityloss: true
+  
+train:
+  epochs: 30
+  base_lr: 6.4e-4
+  warmup_lr: 1.6e-5
+  min_lr: 1.6e-4
+checkpoint:
+  save_freq: 1
+evaluate:
+  eval_freq: 1

+ 114 - 0
configs/ovsegmentor/ovsegmentor_pretrain_vit_bert_stage1.yml

@@ -0,0 +1,114 @@
+_base_: '../default.yml'
+model_name: 'ovsegmentor_pretrain_vit_bert_cc4m_stage1' # display name in the logger
+output: /mnt/petrelfs/xujilan/exps/final_exps/
+
+print_freq: 100
+data:
+  with_dc: False
+  train: 
+      root_dir: [
+                  's3://GCC/GCC12m/',
+                ]
+      meta_file: [
+                  '/mnt/petrelfs/xujilan/data/cc12m_100/cc4m.json',
+                ]
+      read_from: petrel
+      use_dali: True
+      batch_size: 256
+      input_size: 224
+      test_resize: 256
+
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed_epoch
+      transforms:
+          type: STANDARD
+      fseek: True
+      use_ranked: False
+      
+      use_entity: ${model.use_entityloss}
+      mask_type: class
+      use_distilbert: True
+  
+  val:
+      type: clip
+      read_from: petrel
+      use_dali: True
+      batch_size: 64
+      num_workers: 4
+      pin_memory: False
+      input_size: 224
+      test_resize: 256
+      
+      root_dir: '/mnt/cache/share/images/val/'
+      meta_file: 'imagenet_info/val.json'
+      # you can change it to imagenet_info relative path, file already in gitlab
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed
+      transforms:
+          type: ONECROP
+      evaluator:
+          type: imagenet
+          kwargs:
+              topk: [1, 5]
+      label_texts_ensemble: 'prompt1'
+          
+  img_aug:
+    deit_aug: true
+    img_size: 224
+    img_scale: [0.08, 1.0]
+    interpolation: 'bilinear'
+    # interpolation: 2
+    color_jitter: 0.4
+    auto_augment: 'rand-m9-mstd0.5-inc1'
+    re_prob: 0.25
+    re_mode: 'pixel'
+    re_count: 1
+  text_aug:
+    max_seq_len: 77
+    multi_label: 0 # we do not use multi-label contrastive 
+    word_type: 'noun'
+
+model:
+  type: MultiLabelContrastive
+  img_encoder:
+    type: GroupViT
+    embed_dim: 768
+    num_heads: [8, 8]
+    embed_factors: [1, 1]
+    depths: [6, 6]
+    num_group_tokens: [64, 0]
+    num_output_groups: [8]
+    drop_rate: 0.0
+    drop_path_rate: 0.1
+    patch_norm: false
+    imgnet_pretrained: 'dino'
+    fixed: false
+    imgnet_pretrained_checkpoint: '/mnt/petrelfs/xujilan/checkpoints/dino_vitbase16_pretrain.pth'
+
+  text_encoder:
+    type: Bert
+    context_length: 77
+    width: 768
+    layers: 6
+    vocab_size: 49408
+    pretrained: true
+    fixed: false
+  contrast_temperature: 0.07
+  proj_num_layers: 2
+  output_dim: 256
+  multi_label: ${data.text_aug.multi_label}
+  use_entityloss: true
+  
+train:
+  epochs: 30
+  base_lr: 6.4e-4
+  warmup_lr: 1.6e-5
+  min_lr: 1.6e-4
+checkpoint:
+  save_freq: 1
+evaluate:
+  eval_freq: 1

+ 123 - 0
configs/ovsegmentor/ovsegmentor_pretrain_vit_bert_stage2.yml

@@ -0,0 +1,123 @@
+_base_: '../default.yml'
+model_name: 'ovsegmentor_pretrain_vit_bert_cc4m_stage2' # display name in the logger
+output: /mnt/petrelfs/xujilan/exps/final_exps/
+
+print_freq: 100
+data:
+  with_dc: False
+  train: 
+      root_dir: [
+                  's3://GCC/GCC12m/',
+                  ]
+      meta_file: [
+                  '/mnt/petrelfs/xujilan/data/cc12m_100/cc4m.json',
+                  ]
+      read_from: petrel
+      use_dali: True
+      batch_size: 128
+      input_size: 224
+      test_resize: 256
+
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed_epoch
+      transforms:
+          type: STANDARD
+      fseek: True
+      use_ranked: False
+      
+      ### for entity loss ###
+      use_entity: ${model.use_entityloss}
+      mask_type: class
+      use_distilbert: True
+
+      ### for mask loss ### 
+      cross_image: ${model.use_maskloss}
+      class_label_dir: '/mnt/petrelfs/xujilan/data/cc12m_100/class_label.json'
+      sample_list_dir: '/mnt/petrelfs/xujilan/data/cc12m_100/sample_list.json'
+
+  val:
+      type: clip
+      read_from: petrel
+      use_dali: True
+      batch_size: 64
+      num_workers: 4
+      pin_memory: False
+      input_size: 224
+      test_resize: 256
+      
+      root_dir: '/mnt/cache/share/images/val/'
+      meta_file: 'imagenet_info/val.json'
+      # you can change it to imagenet_info relative path, file already in gitlab
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed
+      transforms:
+          type: ONECROP
+      evaluator:
+          type: imagenet
+          kwargs:
+              topk: [1, 5]
+      label_texts_ensemble: 'prompt1'
+          
+  img_aug:
+    deit_aug: false
+    img_size: 224
+    img_scale: [0.4, 1.0]
+    interpolation: 'bilinear'
+    color_jitter: 0.4
+    auto_augment: 'rand-m9-mstd0.5-inc1'
+    re_prob: 0.25
+    re_mode: 'pixel'
+    re_count: 1
+  text_aug:
+    max_seq_len: 77
+    multi_label: 0 #changed to singlelabel
+    word_type: 'noun'
+
+model:
+  type: MultiLabelContrastive
+  img_encoder:
+    type: GroupViT
+    embed_dim: 768
+    num_heads: [8, 8]
+    embed_factors: [1, 1]
+    depths: [6, 6]
+    num_group_tokens: [64, 0]
+    num_output_groups: [8]
+    drop_rate: 0.0
+    drop_path_rate: 0.1
+    patch_norm: false
+    imgnet_pretrained: 'dino'
+    fixed: false
+    imgnet_pretrained_checkpoint: '/mnt/petrelfs/xujilan/checkpoints/dino_vitbase16_pretrain.pth'
+
+  text_encoder:
+    type: Bert
+    context_length: 77
+    width: 768
+    layers: 6
+    vocab_size: 49408
+    pretrained: true
+    fixed: true
+  contrast_temperature: 0.07
+  proj_num_layers: 2
+  output_dim: 256
+  multi_label: ${data.text_aug.multi_label}
+
+  use_entityloss: true
+  use_maskloss: true
+  cross_threshold: 0.6
+  
+train:
+  epochs: 10
+  base_lr: 1e-5
+  min_lr: 1e-6
+  warmup_epochs: 0
+checkpoint:
+  save_freq: 1
+  stage1_checkpoint: /mnt/petrelfs/xujilan/exps/best_miou.pth
+evaluate:
+  eval_freq: 1

+ 121 - 0
configs/test_ade20k.yml

@@ -0,0 +1,121 @@
+_base_: 'default.yml'
+model_name: 'test_ade20k' # display name in the logger
+output: /mnt/petrelfs/xujilan/exps/cc12m_100/
+
+print_freq: 100
+data:
+  with_dc: False
+  train: 
+      root_dir: [
+                  's3://GCC/GCC12m/',
+                  ]
+      meta_file: [
+                  '/mnt/cache/share_data/DSK_datasets/cc12m/subset/cc12m_top100_coconouns.json',
+                  ]
+      read_from: petrel
+      use_dali: True
+      batch_size: 256
+      input_size: 224
+      test_resize: 256
+
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed_epoch
+      transforms:
+          type: STANDARD
+      fseek: True
+      use_ranked: False
+
+      ### for entity loss ###
+      use_entity: ${model.use_entityloss}
+      mask_type: class
+      use_distilbert: True
+
+      ### for mask loss ### 
+      cross_image: ${model.use_maskloss}
+      
+      
+  val:
+      type: clip
+      read_from: petrel
+      use_dali: True
+      batch_size: 64
+      num_workers: 4
+      pin_memory: False
+      input_size: 224
+      test_resize: 256
+      
+      root_dir: '/mnt/cache/share/images/val/'
+      meta_file: 'imagenet_info/val.json'
+      # you can change it to imagenet_info relative path, file already in gitlab
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed
+      transforms:
+          type: ONECROP
+      evaluator:
+          type: imagenet
+          kwargs:
+              topk: [1, 5]
+      label_texts_ensemble: 'prompt1'
+          
+  img_aug:
+    deit_aug: false
+    img_size: 224
+    img_scale: [0.4, 1.0]
+    interpolation: 'bilinear'
+    color_jitter: 0.4
+    auto_augment: 'rand-m9-mstd0.5-inc1'
+    re_prob: 0.25
+    re_mode: 'pixel'
+    re_count: 1
+  text_aug:
+    max_seq_len: 77
+    multi_label: 0 #changed to singlelabel
+    word_type: 'noun'
+
+model:
+  type: MultiLabelContrastive
+  img_encoder:
+    type: GroupViT
+    embed_dim: 768
+    num_heads: [8, 8]
+    embed_factors: [1, 1]
+    depths: [6, 6]
+    num_group_tokens: [64, 0]
+    num_output_groups: [8]
+    drop_rate: 0.0
+    drop_path_rate: 0.1
+    patch_norm: false
+    imgnet_pretrained: 'dino'
+    fixed: false
+
+  text_encoder:
+    type: Bert
+    context_length: 77
+    width: 768
+    layers: 6
+    vocab_size: 49408
+    pretrained: true
+    fixed: true
+  contrast_temperature: 0.07
+  proj_num_layers: 2
+  output_dim: 256
+  multi_label: ${data.text_aug.multi_label}
+  
+  use_entityloss: true
+  use_maskloss: true
+  cross_threshold: 0.6
+  
+train:
+  epochs: 50
+  base_lr: 1.6e-4
+checkpoint:
+  save_freq: 1
+evaluate:
+  eval_freq: 1
+  seg:
+    cfg: segmentation/configs/_base_/datasets/ade20k.py
+

+ 123 - 0
configs/test_coco.yml

@@ -0,0 +1,123 @@
+_base_: 'default.yml'
+model_name: 'test_coco' # display name in the logger
+output: /mnt/petrelfs/xujilan/exps/cc12m_100/
+
+print_freq: 100
+data:
+  with_dc: False
+  train: 
+      root_dir: [
+                  's3://GCC/GCC12m/',
+                  ]
+      meta_file: [
+                  '/mnt/cache/share_data/DSK_datasets/cc12m/subset/cc12m_top100_coconouns.json',
+                  ]
+      read_from: petrel
+      use_dali: True
+      batch_size: 256
+      input_size: 224
+      test_resize: 256
+
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed_epoch
+      transforms:
+          type: STANDARD
+      fseek: True
+      use_ranked: False
+
+      ### for entity loss ###
+      use_entity: ${model.use_entityloss}
+      mask_type: class
+      use_distilbert: True
+
+      ### for mask loss ### 
+      cross_image: ${model.use_maskloss}
+      
+      
+  val:
+      type: clip
+      read_from: petrel
+      use_dali: True
+      batch_size: 64
+      num_workers: 4
+      pin_memory: False
+      input_size: 224
+      test_resize: 256
+      
+      root_dir: '/mnt/cache/share/images/val/'
+      meta_file: 'imagenet_info/val.json'
+      # you can change it to imagenet_info relative path, file already in gitlab
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed
+      transforms:
+          type: ONECROP
+      evaluator:
+          type: imagenet
+          kwargs:
+              topk: [1, 5]
+      label_texts_ensemble: 'prompt1'
+          
+  img_aug:
+    deit_aug: false
+    img_size: 224
+    img_scale: [0.4, 1.0]
+    interpolation: 'bilinear'
+    color_jitter: 0.4
+    auto_augment: 'rand-m9-mstd0.5-inc1'
+    re_prob: 0.25
+    re_mode: 'pixel'
+    re_count: 1
+  text_aug:
+    max_seq_len: 77
+    multi_label: 0 #changed to singlelabel
+    word_type: 'noun'
+
+model:
+  type: MultiLabelContrastive
+  img_encoder:
+    type: GroupViT
+    embed_dim: 768
+    num_heads: [8, 8]
+    embed_factors: [1, 1]
+    depths: [6, 6]
+    num_group_tokens: [64, 0]
+    num_output_groups: [8]
+    drop_rate: 0.0
+    drop_path_rate: 0.1
+    patch_norm: false
+    imgnet_pretrained: 'dino'
+    fixed: false
+
+  text_encoder:
+    type: Bert
+    context_length: 77
+    width: 768
+    layers: 6
+    vocab_size: 49408
+    pretrained: true
+    fixed: true
+  contrast_temperature: 0.07
+  proj_num_layers: 2
+  output_dim: 256
+  multi_label: ${data.text_aug.multi_label}
+  
+  use_entityloss: true
+  use_maskloss: true
+  cross_threshold: 0.6
+  
+train:
+  epochs: 50
+  base_lr: 1.6e-4
+checkpoint:
+  save_freq: 1
+evaluate:
+  eval_freq: 1
+  seg:
+    cfg: segmentation/configs/_base_/datasets/coco.py
+
+
+# vis: ['input_pred']

+ 121 - 0
configs/test_coco_stuff.yml

@@ -0,0 +1,121 @@
+_base_: 'default.yml'
+model_name: 'test_cocostuff' # display name in the logger
+output: /mnt/petrelfs/xujilan/exps/cc12m_100/
+
+print_freq: 100
+data:
+  with_dc: False
+  train: 
+      root_dir: [
+                  's3://GCC/GCC12m/',
+                  ]
+      meta_file: [
+                  '/mnt/cache/share_data/DSK_datasets/cc12m/subset/cc12m_top100_coconouns.json',
+                  ]
+      read_from: petrel
+      use_dali: True
+      batch_size: 256
+      input_size: 224
+      test_resize: 256
+
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed_epoch
+      transforms:
+          type: STANDARD
+      fseek: True
+      use_ranked: False
+
+      ### for entity loss ###
+      use_entity: ${model.use_entityloss}
+      mask_type: class
+      use_distilbert: True
+
+      ### for mask loss ### 
+      cross_image: ${model.use_maskloss}
+      
+      
+  val:
+      type: clip
+      read_from: petrel
+      use_dali: True
+      batch_size: 64
+      num_workers: 4
+      pin_memory: False
+      input_size: 224
+      test_resize: 256
+      
+      root_dir: '/mnt/cache/share/images/val/'
+      meta_file: 'imagenet_info/val.json'
+      # you can change it to imagenet_info relative path, file already in gitlab
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed
+      transforms:
+          type: ONECROP
+      evaluator:
+          type: imagenet
+          kwargs:
+              topk: [1, 5]
+      label_texts_ensemble: 'prompt1'
+          
+  img_aug:
+    deit_aug: false
+    img_size: 224
+    img_scale: [0.4, 1.0]
+    interpolation: 'bilinear'
+    color_jitter: 0.4
+    auto_augment: 'rand-m9-mstd0.5-inc1'
+    re_prob: 0.25
+    re_mode: 'pixel'
+    re_count: 1
+  text_aug:
+    max_seq_len: 77
+    multi_label: 0 #changed to singlelabel
+    word_type: 'noun'
+
+model:
+  type: MultiLabelContrastive
+  img_encoder:
+    type: GroupViT
+    embed_dim: 768
+    num_heads: [8, 8]
+    embed_factors: [1, 1]
+    depths: [6, 6]
+    num_group_tokens: [64, 0]
+    num_output_groups: [8]
+    drop_rate: 0.0
+    drop_path_rate: 0.1
+    patch_norm: false
+    imgnet_pretrained: 'dino'
+    fixed: false
+
+  text_encoder:
+    type: Bert
+    context_length: 77
+    width: 768
+    layers: 6
+    vocab_size: 49408
+    pretrained: true
+    fixed: true
+  contrast_temperature: 0.07
+  proj_num_layers: 2
+  output_dim: 256
+  multi_label: ${data.text_aug.multi_label}
+  
+  use_entityloss: true
+  use_maskloss: true
+  cross_threshold: 0.6
+  
+train:
+  epochs: 50
+  base_lr: 1.6e-4
+checkpoint:
+  save_freq: 1
+evaluate:
+  eval_freq: 1
+  seg:
+    cfg: segmentation/configs/_base_/datasets/coco_stuff.py
+

+ 118 - 0
configs/test_voc12.yml

@@ -0,0 +1,118 @@
+_base_: 'default.yml'
+model_name: 'test_voc12' # display name in the logger
+output: /mnt/petrelfs/xujilan/exps/cc12m_100/
+
+print_freq: 100
+data:
+  with_dc: False
+  train: 
+      root_dir: [
+                  's3://GCC/GCC12m/',
+                  ]
+      meta_file: [
+                  '/mnt/cache/share_data/DSK_datasets/cc12m/subset/cc12m_top100_coconouns.json',
+                  ]
+      read_from: petrel
+      use_dali: True
+      batch_size: 256
+      input_size: 224
+      test_resize: 256
+
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed_epoch
+      transforms:
+          type: STANDARD
+      fseek: True
+      use_ranked: False
+
+      ### for entity loss ###
+      use_entity: ${model.use_entityloss}
+      mask_type: class
+      use_distilbert: True
+
+      ### for mask loss ### 
+      cross_image: ${model.use_maskloss}
+      
+      
+  val:
+      type: clip
+      read_from: petrel
+      use_dali: True
+      batch_size: 64
+      num_workers: 4
+      pin_memory: False
+      input_size: 224
+      test_resize: 256
+      
+      root_dir: '/mnt/cache/share/images/val/'
+      meta_file: 'imagenet_info/val.json'
+      # you can change it to imagenet_info relative path, file already in gitlab
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed
+      transforms:
+          type: ONECROP
+      evaluator:
+          type: imagenet
+          kwargs:
+              topk: [1, 5]
+      label_texts_ensemble: 'prompt1'
+          
+  img_aug:
+    deit_aug: false
+    img_size: 224
+    img_scale: [0.4, 1.0]
+    interpolation: 'bilinear'
+    color_jitter: 0.4
+    auto_augment: 'rand-m9-mstd0.5-inc1'
+    re_prob: 0.25
+    re_mode: 'pixel'
+    re_count: 1
+  text_aug:
+    max_seq_len: 77
+    multi_label: 0 #changed to singlelabel
+    word_type: 'noun'
+
+model:
+  type: MultiLabelContrastive
+  img_encoder:
+    type: GroupViT
+    embed_dim: 768
+    num_heads: [8, 8]
+    embed_factors: [1, 1]
+    depths: [6, 6]
+    num_group_tokens: [64, 0]
+    num_output_groups: [8]
+    drop_rate: 0.0
+    drop_path_rate: 0.1
+    patch_norm: false
+    imgnet_pretrained: 'dino'
+    fixed: false
+
+  text_encoder:
+    type: Bert
+    context_length: 77
+    width: 768
+    layers: 6
+    vocab_size: 49408
+    pretrained: true
+    fixed: true
+  contrast_temperature: 0.07
+  proj_num_layers: 2
+  output_dim: 256
+  multi_label: ${data.text_aug.multi_label}
+  
+  use_entityloss: true
+  use_maskloss: true
+  cross_threshold: 0.6
+  
+train:
+  epochs: 50
+  base_lr: 1.6e-4
+checkpoint:
+  save_freq: 1
+evaluate:
+  eval_freq: 1

+ 123 - 0
configs/test_voc_context.yml

@@ -0,0 +1,123 @@
+_base_: 'default.yml'
+model_name: 'test_context' # display name in the logger
+output: /mnt/petrelfs/xujilan/exps/cc12m_100/
+
+print_freq: 100
+data:
+  with_dc: False
+  train: 
+      root_dir: [
+                  's3://GCC/GCC12m/',
+                  ]
+      meta_file: [
+                  '/mnt/cache/share_data/DSK_datasets/cc12m/subset/cc12m_top100_coconouns.json',
+                  ]
+      read_from: petrel
+      use_dali: True
+      batch_size: 256
+      input_size: 224
+      test_resize: 256
+
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed_epoch
+      transforms:
+          type: STANDARD
+      fseek: True
+      use_ranked: False
+
+      ### for entity loss ###
+      use_entity: ${model.use_entityloss}
+      mask_type: class
+      use_distilbert: True
+
+      ### for mask loss ### 
+      cross_image: ${model.use_maskloss}
+      
+      
+  val:
+      type: clip
+      read_from: petrel
+      use_dali: True
+      batch_size: 64
+      num_workers: 4
+      pin_memory: False
+      input_size: 224
+      test_resize: 256
+      
+      root_dir: '/mnt/cache/share/images/val/'
+      meta_file: 'imagenet_info/val.json'
+      # you can change it to imagenet_info relative path, file already in gitlab
+      image_reader:
+          type: pil
+      sampler:
+          type: distributed
+      transforms:
+          type: ONECROP
+      evaluator:
+          type: imagenet
+          kwargs:
+              topk: [1, 5]
+      label_texts_ensemble: 'prompt1'
+          
+  img_aug:
+    deit_aug: false
+    img_size: 224
+    img_scale: [0.4, 1.0]
+    interpolation: 'bilinear'
+    color_jitter: 0.4
+    auto_augment: 'rand-m9-mstd0.5-inc1'
+    re_prob: 0.25
+    re_mode: 'pixel'
+    re_count: 1
+  text_aug:
+    max_seq_len: 77
+    multi_label: 0 #changed to singlelabel
+    word_type: 'noun'
+
+model:
+  type: MultiLabelContrastive
+  img_encoder:
+    type: GroupViT
+    embed_dim: 768
+    num_heads: [8, 8]
+    embed_factors: [1, 1]
+    depths: [6, 6]
+    num_group_tokens: [64, 0]
+    num_output_groups: [8]
+    drop_rate: 0.0
+    drop_path_rate: 0.1
+    patch_norm: false
+    imgnet_pretrained: 'dino'
+    fixed: false
+
+  text_encoder:
+    type: Bert
+    context_length: 77
+    width: 768
+    layers: 6
+    vocab_size: 49408
+    pretrained: true
+    fixed: true
+  contrast_temperature: 0.07
+  proj_num_layers: 2
+  output_dim: 256
+  multi_label: ${data.text_aug.multi_label}
+  
+  use_entityloss: true
+  use_maskloss: true
+  cross_threshold: 0.6
+  
+train:
+  epochs: 50
+  base_lr: 1.6e-4
+checkpoint:
+  save_freq: 1
+evaluate:
+  eval_freq: 1
+  seg:
+    cfg: segmentation/configs/_base_/datasets/pascal_context.py
+
+
+# vis: ['input_pred']

+ 288 - 0
convert_dataset/convert_coco_object.py

@@ -0,0 +1,288 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import argparse
+import os.path as osp
+import shutil
+from functools import partial
+from glob import glob
+
+import mmcv
+import numpy as np
+from PIL import Image
+
+COCO_LEN = 123287
+
+clsID_to_trID = {
+    0: 0,
+    1: 1,
+    2: 2,
+    3: 3,
+    4: 4,
+    5: 5,
+    6: 6,
+    7: 7,
+    8: 8,
+    9: 9,
+    10: 10,
+    12: 11,
+    13: 12,
+    14: 13,
+    15: 14,
+    16: 15,
+    17: 16,
+    18: 17,
+    19: 18,
+    20: 19,
+    21: 20,
+    22: 21,
+    23: 22,
+    24: 23,
+    26: 24,
+    27: 25,
+    30: 26,
+    31: 27,
+    32: 28,
+    33: 29,
+    34: 30,
+    35: 31,
+    36: 32,
+    37: 33,
+    38: 34,
+    39: 35,
+    40: 36,
+    41: 37,
+    42: 38,
+    43: 39,
+    45: 40,
+    46: 41,
+    47: 42,
+    48: 43,
+    49: 44,
+    50: 45,
+    51: 46,
+    52: 47,
+    53: 48,
+    54: 49,
+    55: 50,
+    56: 51,
+    57: 52,
+    58: 53,
+    59: 54,
+    60: 55,
+    61: 56,
+    62: 57,
+    63: 58,
+    64: 59,
+    66: 60,
+    69: 61,
+    71: 62,
+    72: 63,
+    73: 64,
+    74: 65,
+    75: 66,
+    76: 67,
+    77: 68,
+    78: 69,
+    79: 70,
+    80: 71,
+    81: 72,
+    83: 73,
+    84: 74,
+    85: 75,
+    86: 76,
+    87: 77,
+    88: 78,
+    89: 79,
+    91: 80,
+    92: 81,
+    93: 82,
+    94: 83,
+    95: 84,
+    96: 85,
+    97: 86,
+    98: 87,
+    99: 88,
+    100: 89,
+    101: 90,
+    102: 91,
+    103: 92,
+    104: 93,
+    105: 94,
+    106: 95,
+    107: 96,
+    108: 97,
+    109: 98,
+    110: 99,
+    111: 100,
+    112: 101,
+    113: 102,
+    114: 103,
+    115: 104,
+    116: 105,
+    117: 106,
+    118: 107,
+    119: 108,
+    120: 109,
+    121: 110,
+    122: 111,
+    123: 112,
+    124: 113,
+    125: 114,
+    126: 115,
+    127: 116,
+    128: 117,
+    129: 118,
+    130: 119,
+    131: 120,
+    132: 121,
+    133: 122,
+    134: 123,
+    135: 124,
+    136: 125,
+    137: 126,
+    138: 127,
+    139: 128,
+    140: 129,
+    141: 130,
+    142: 131,
+    143: 132,
+    144: 133,
+    145: 134,
+    146: 135,
+    147: 136,
+    148: 137,
+    149: 138,
+    150: 139,
+    151: 140,
+    152: 141,
+    153: 142,
+    154: 143,
+    155: 144,
+    156: 145,
+    157: 146,
+    158: 147,
+    159: 148,
+    160: 149,
+    161: 150,
+    162: 151,
+    163: 152,
+    164: 153,
+    165: 154,
+    166: 155,
+    167: 156,
+    168: 157,
+    169: 158,
+    170: 159,
+    171: 160,
+    172: 161,
+    173: 162,
+    174: 163,
+    175: 164,
+    176: 165,
+    177: 166,
+    178: 167,
+    179: 168,
+    180: 169,
+    181: 170,
+    255: 255
+}
+
+# set to background for cocoobject with first 80 foreground classes
+# for k, v in clsID_to_trID.items():
+#     clsID_to_trID[k] = v + 1
+#     if k > 90:
+#         clsID_to_trID[k] = 0
+
+# for all foreground classes
+for k, v in clsID_to_trID.items():
+    clsID_to_trID[k] = v + 1
+    if k == 255:
+        clsID_to_trID[k] = 0
+
+
+def convert_to_trainID(maskpath, out_mask_dir, is_train):
+    mask = np.array(Image.open(maskpath))
+    mask_copy = mask.copy()
+    for clsID, trID in clsID_to_trID.items():
+        mask_copy[mask == clsID] = trID
+    seg_filename = osp.join(
+        out_mask_dir, 'train2017',
+        osp.basename(maskpath).split('.')[0] +
+        '_instanceTrainIds.png') if is_train else osp.join(
+            out_mask_dir, 'val2017',
+            osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png')
+    Image.fromarray(mask_copy).save(seg_filename, 'PNG')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description=\
+        'Convert COCO Stuff 164k annotations to COCO Objects')  # noqa
+    parser.add_argument('coco_path', help='coco stuff path')
+    parser.add_argument('-o', '--out_dir', help='output path')
+    parser.add_argument(
+        '--nproc', default=16, type=int, help='number of process')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    coco_path = args.coco_path
+    nproc = args.nproc
+
+    out_dir = args.out_dir or coco_path
+    out_img_dir = osp.join(out_dir, 'images')
+    out_mask_dir = osp.join(out_dir, 'annotations')
+
+    mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
+    mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))
+
+    if out_dir != coco_path:
+        shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)
+
+    train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
+    train_list = [file for file in train_list if 'TrainIds' not in file]
+    test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
+    test_list = [file for file in test_list if 'TrainIds' not in file]
+    assert (len(train_list) +
+            len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
+                len(train_list), len(test_list))
+
+    if args.nproc > 1:
+        mmcv.track_parallel_progress(
+            partial(
+                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
+            train_list,
+            nproc=nproc)
+        mmcv.track_parallel_progress(
+            partial(
+                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
+            test_list,
+            nproc=nproc)
+    else:
+        mmcv.track_progress(
+            partial(
+                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
+            train_list)
+        mmcv.track_progress(
+            partial(
+                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
+            test_list)
+
+    print('Done!')
+
+
+if __name__ == '__main__':
+    main()

+ 119 - 0
convert_dataset/convert_yfcc14m.py

@@ -0,0 +1,119 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import argparse
+import json
+import os
+import os.path as osp
+import random
+import sys
+import zipfile
+
+import numpy as np
+import pandas as pd
+import webdataset as wds
+from tqdm import tqdm
+import mmcv
+
+def write_dataset(args):
+
+    df = pd.read_csv(
+        args.info, sep='\t', index_col='file', dtype=str, lineterminator='\n')
+    print(f'Loaded dataframe: \n{df}')
+    print(f'Length: \n{len(df)}')
+
+    # This is the output pattern under which we write shards.
+    pattern = os.path.join(args.shards, f'yfcc14m-%06d.tar')
+
+    with wds.ShardWriter(
+            pattern, maxsize=int(args.maxsize),
+            maxcount=int(args.maxcount)) as sink:
+        sink.verbose = 0
+        all_keys = set()
+
+        skipped = 0
+        zip_files = list(mmcv.scandir(args.root, suffix='zip'))
+        for idx, file in tqdm(
+                enumerate(zip_files), desc='total', total=len(zip_files)):
+            with zipfile.ZipFile(osp.join(args.root, file), 'r') as zfile:
+                filename_list = zfile.namelist()
+                for filename in tqdm(
+                        filename_list, position=1, desc=f'{file}', leave=None):
+                    image = zfile.read(filename)
+                    if image is None:
+                        skipped += 1
+                        tqdm.write(f'Skipping {filename}, {skipped}/{len(df)}')
+                        continue
+                    fname = filename.replace('data/images/', '')
+                    # Construct a unique key from the filename.
+                    key = os.path.splitext(os.path.basename(fname))[0]
+
+                    # Useful check.
+                    if key in all_keys:
+                        tqdm.write(f'duplicate: {fname}')
+                        continue
+                    assert key not in all_keys
+                    all_keys.add(key)
+
+                    text = str(df.loc[fname]['caption'])
+
+                    if len(text.split(' ')) < 2:
+                        skipped += 1
+                        tqdm.write(f'Text {text} too short')
+                        tqdm.write(f'Skipping {fname}, {skipped}/{len(df)}')
+                        continue
+
+                    # Construct a sample.
+                    xkey = key
+                    sample = {'__key__': xkey, 'jpg': image, 'text': text}
+
+                    # Write the sample to the sharded tar archives.
+                    sink.write(sample)
+        print(f'skipped: {skipped}/{len(df)}')
+        print(f'total keys: {len(all_keys)}')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        """Generate sharded dataset from original YFCC14M data.""")
+    parser.add_argument('--maxsize', type=float, default=1e9)
+    parser.add_argument('--maxcount', type=float, default=100000)
+    parser.add_argument('--shards', help='directory where shards are written')
+    parser.add_argument('--root', help='data root path')
+    parser.add_argument('--info', help='tsv path')
+    args = parser.parse_args()
+
+    assert args.maxsize > 10000000
+    assert args.maxcount < 1000000
+    return args
+
+
+def main():
+    args = parse_args()
+
+    seed = 0
+    random.seed(seed)
+    np.random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+
+    if not os.path.isdir(os.path.join(args.shards, '.')):
+        print(
+            f'{args.shards}: should be a writable destination directory for shards',
+            file=sys.stderr)
+        sys.exit(1)
+
+    write_dataset(args=args)
+
+
+if __name__ == '__main__':
+    main()

+ 114 - 0
convert_dataset/create_subset.py

@@ -0,0 +1,114 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import os
+import os.path as osp
+import argparse
+
+import pandas as pd
+import sqlite3
+import pandas as pd
+import os.path as osp
+from urllib.parse import unquote
+import re
+from datadings.tools import locate_files
+from yfcc100m.vars import FILES
+from yfcc100m.convert_metadata import download_db
+from pandarallel import pandarallel
+
+pandarallel.initialize(progress_bar=True)
+
+
+def key2path(key):
+    img_path = osp.join(key[0:3], key[3:6], key + '.jpg')
+    return img_path
+
+
+def clean_caption(line):
+    line = unquote(str(line))
+    line = remove_html_tags(line)
+    return line.replace('\n', ' ').replace('+', ' ')
+
+
+def remove_html_tags(text):
+    """Remove html tags from a string"""
+    clean = re.compile('<.*?>')
+    return re.sub(clean, '', text)
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Create YFCC subset sql db and tsv')
+    parser.add_argument('--input-dir', help='input sql db file directory')
+    parser.add_argument('--output-dir', help='output tsv directory')
+    parser.add_argument(
+        '--subset', help='subset of data to use', default='yfcc100m_subset_data.tsv')
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    files = locate_files(FILES, args.input_dir)
+    # download DB file with AWS tools
+    download_db(files)
+
+    fullset_name = 'yfcc100m_dataset'
+    subset_name = 'yfcc14m_dataset'
+    conn = sqlite3.connect(osp.join(args.input_dir, 'yfcc100m_dataset.sql'))
+    # get column names
+    # some settings that hopefully speed up the queries
+    # conn.execute(f'PRAGMA query_only = YES')
+    conn.execute(f'PRAGMA journal_mode = OFF')
+    conn.execute(f'PRAGMA locking_mode = EXCLUSIVE')
+    conn.execute(f'PRAGMA page_size = 4096')
+    conn.execute(f'PRAGMA mmap_size = {4*1024*1024}')
+    conn.execute(f'PRAGMA cache_size = 10000')
+
+    print('reading subset data')
+    subset_df = pd.read_csv(args.subset, sep='\t', usecols=[1, 2], names=['photoid', 'photo_hash'], index_col='photoid')
+    subset_df.to_sql(subset_name, con=conn, if_exists='replace')
+
+    print('overwriting with subset')
+    select_query = f'select {fullset_name}.*, {subset_name}.photo_hash from {fullset_name} inner join {subset_name} on {fullset_name}.photoid = {subset_name}.photoid'
+    new_name = 'yfcc100m_dataset_new'
+    print('creating new table')
+    conn.execute(f'drop table if exists {new_name}')
+    conn.execute(' '.join([f'create table {new_name} as ', select_query]))
+    print(f'droping {fullset_name}')
+    conn.execute(f'drop table if exists {fullset_name}')
+    print(f'droping {subset_name}')
+    conn.execute(f'drop table if exists {subset_name}')
+    print(f'renaming {new_name} to {fullset_name}')
+    conn.execute(f'alter table {new_name} rename to {fullset_name}')
+    print('vacuuming db')
+    conn.execute('vacuum')
+
+    print(f'Loading dataframe from SQL')
+    anno_df = pd.read_sql(f'select * from {fullset_name}', con=conn)
+    print(f'Loaded dataframe from SQL: \n{anno_df.head()}')
+    print(f'Length: \n{len(anno_df)}')
+    print(f'generating filepath')
+    anno_df['file'] = anno_df['photo_hash'].parallel_map(key2path)
+    anno_df['caption'] = anno_df['description'].parallel_map(clean_caption)
+    anno_df = anno_df[['file', 'caption']]
+    print(f'Generated dataframe: \n{anno_df.head()}')
+
+    print('saving subset as tsv')
+    os.makedirs(args.output_dir, exist_ok=True)
+    anno_df.to_csv(osp.join(args.output_dir, 'yfcc14m_dataset.tsv'), sep='\t', index=False)
+    conn.close()
+
+
+if __name__ == '__main__':
+    main()

+ 71 - 0
convert_dataset/process_redcaps.py

@@ -0,0 +1,71 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import argparse
+import json
+import os
+import os.path as osp
+import random
+
+import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
+import tqdm
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        'input', type=str, help='path to redcaps annotations directory')
+    parser.add_argument(
+        'output', type=str, help='output annotations file path')
+    parser.add_argument(
+        '--num-split', type=int, help='number of splits to make')
+    return parser
+
+
+def main(args):
+    annos = []
+    for fname in tqdm.tqdm(os.listdir(args.input), desc='merging json files'):
+        if fname.endswith('json'):
+            with open(os.path.join(args.input, fname)) as f:
+                a = json.load(f)
+                for d in a['annotations']:
+                    cur_d = {'URL': d['url'], 'TEXT': d['caption']}
+                    annos.append(cur_d)
+
+    random.seed(42)
+    random.shuffle(annos)
+    if args.num_split is None:
+        df = pd.DataFrame(annos)
+        print(df.head())
+        print(f'saving {len(df)} annotations to {args.output}')
+        table = pa.Table.from_pandas(df)
+        os.makedirs(osp.dirname(args.output), exist_ok=True)
+        pq.write_table(table, args.output)
+    else:
+        for i in range(args.num_split):
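+            # round-robin split: part i takes annotations i, i + num_split, i + 2 * num_split, ...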
+            df = pd.DataFrame(annos[i::args.num_split])
+            print(df.head())
+            output = osp.splitext(
+                args.output)[0] + f'_part{i}{osp.splitext(args.output)[1]}'
+            print(f'saving {len(df)} annotations to {output}')
+            table = pa.Table.from_pandas(df)
+            os.makedirs(osp.dirname(output), exist_ok=True)
+            pq.write_table(table, output)
+
+
+if __name__ == '__main__':
+    parser = get_args_parser()
+    args = parser.parse_args()
+    main(args)

+ 20 - 0
datasets/__init__.py

@@ -0,0 +1,20 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+
+from .builder import build_loader, build_text_transform
+from .imagenet_template import imagenet_classes, template_meta
+
+__all__ = [
+    'build_loader', 'build_text_transform', 'template_meta', 'imagenet_classes'
+]

+ 119 - 0
datasets/base_dataset.py

@@ -0,0 +1,119 @@
+# -------------------------------------------------------------------------
+# Written by Jilan Xu
+# -------------------------------------------------------------------------
+
+import os
+# import linklink as link
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+try:
+    import mc
+except ImportError:
+    pass
+# import ceph
+from petrel_client.client import Client
+
+
+class BaseDataset(Dataset):
+    def __init__(self,
+                 root_dir,
+                 meta_file,
+                 transform=None,
+                 read_from='mc',
+                 evaluator=None):
+
+        super(BaseDataset, self).__init__()
+
+        self.root_dir = root_dir
+        self.meta_file = meta_file
+        self.transform = transform
+        self.read_from = read_from
+        self.evaluator = evaluator
+        self.initialized = False
+        if self.read_from == 'petrel':
+            self._init_petrel()
+        else:
+            raise NotImplementedError
+
+    def __len__(self):
+        """
+        Returns dataset length
+        """
+        raise NotImplementedError
+
+    def __getitem__(self, idx):
+        """
+        Get a single sample from the dataset
+
+        Arguments:
+            - idx (:obj:`int`): index of image, 0 <= idx < len(self)
+        """
+        raise NotImplementedError
+
+    def _init_petrel(self):
+        if not self.initialized:
+            self.client = Client('/mnt/petrelfs/xujilan/petreloss.conf')
+            self.initialized = True
+        
+    def read_file(self, meta_dict):
+        value = self.client.get(meta_dict['filename'])
+        filebytes = np.frombuffer(value, dtype=np.uint8)
+        return filebytes
+
+    def dump(self, writer, output):
+        """
+        Dump classification results
+
+        Arguments:
+            - writer: output stream
+            - output (:obj:`dict`): different for imagenet and custom
+        """
+        raise NotImplementedError
+
+    def merge(self, prefix):
+        """
+        Merge results into one file.
+
+        Arguments:
+            - prefix (:obj:`str`): dir/results.rank
+        """
+        # linklink is not imported (its import is commented out above); fall back to torch.distributed
+        world_size = torch.distributed.get_world_size()
+        merged_file = prefix.rsplit('.', 1)[0] + '.all'
+        merged_fd = open(merged_file, 'w')
+        for rank in range(world_size):
+            res_file = prefix + str(rank)
+            assert os.path.exists(res_file), f'No such file or directory: {res_file}'
+            with open(res_file, 'r') as fin:
+                for line_idx, line in enumerate(fin):
+                    merged_fd.write(line)
+        merged_fd.close()
+        return merged_file
+
+    def inference(self, res_file):
+        """
+        Arguments:
+            - res_file (:obj:`str`): filename of result
+        """
+        prefix = res_file.rstrip('0123456789')
+        merged_res_file = self.merge(prefix)
+        return merged_res_file
+
+    def evaluate(self, res_file):
+        """
+        Arguments:
+            - res_file (:obj:`str`): filename of result
+        """
+        prefix = res_file.rstrip('0123456789')
+        merged_res_file = self.merge(prefix)
+        metrics = self.evaluator.eval(merged_res_file) if self.evaluator else {}
+        return metrics
+
+    def tensor2numpy(self, x):
+        if x is None:
+            return x
+        if torch.is_tensor(x):
+            return x.cpu().numpy()
+        if isinstance(x, list):
+            x = [_.cpu().numpy() if torch.is_tensor(_) else _ for _ in x]
+        return x

BIN
datasets/bpe_simple_vocab_16e6.txt.gz


+ 368 - 0
datasets/builder.py

@@ -0,0 +1,368 @@
+# -------------------------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE
+#
+# Written by Ze Liu, Zhenda Xie
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+
+import os.path as osp
+import random
+import warnings
+from functools import partial
+
+import nltk
+import numpy as np
+import torch
+import torch.distributed as dist
+import webdataset as wds
+from braceexpand import braceexpand
+from mmcv.parallel import collate
+from timm.data import create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+import timm
+if timm.__version__ == '0.6.12':
+    from timm.data.transforms import str_to_pil_interp as _pil_interp
+else:
+    from timm.data.transforms import _pil_interp
+# this works for timm==0.3.2
+# from timm.data.transforms import _pil_interp 
+from torchvision import transforms
+import torch.nn as nn
+from PIL import ImageFilter,Image
+from torch import Tensor
+from typing import Tuple, List, Optional
+import numbers
+import math
+import torchvision.transforms.functional as F
+import shutil
+
+from .formatting import ToDataContainer
+from .tokenizer import SimpleTokenizer
+from .clip_dataset import ClipDataset
+from ipdb import set_trace
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    # The seed of each worker equals to
+    # num_worker * rank + worker_id + user_seed
+    worker_seed = num_workers * rank + worker_id + seed
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+
+def collate_fn(batch):  
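+    # stack image/caption tensors and pass raw strings through as lists;
+    # keys that are absent from the batch (e.g. VQA or cross-image fields) are returned as None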
+    img = torch.stack([b['image'] for b in batch])
+    caption = torch.stack([b['caption'] for b in batch])
+    raw_caption = [b['raw_caption'] for b in batch] 
+    
+    raw_question = [b['raw_question'] for b in batch] if 'raw_question' in batch[0].keys() else None
+    raw_answer = [b['raw_answer'] for b in batch] if 'raw_answer' in batch[0].keys() else None
+
+    cross_image = torch.stack([b['cross_image'] for b in batch]) if 'cross_image' in batch[0].keys() else None
+    cross_entity = [b['cross_entity'] for b in batch] if 'cross_entity' in batch[0].keys() else None
+    
+    question = torch.stack([b['question'] for b in batch]) if 'question' in batch[0].keys() and batch[0]['question'] is not None else None
+    answer = torch.stack([b['answer'] for b in batch]) if 'answer' in batch[0].keys() and batch[0]['answer'] is not None else None
+        
+    return {    
+        'image':img,
+        'caption':caption,
+        'raw_caption' : raw_caption,
+        'raw_question': raw_question,
+        'raw_answer': raw_answer,
+        
+        'cross_image': cross_image,
+        'cross_entity': cross_entity, 
+        
+        'question': question,
+        'answer': answer,
+        
+    }
+
+def build_loader(config):
+    local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0
+    global_rank = dist.get_rank() if dist.is_initialized() else 0
+
+    dataset_train = build_dataset(is_train=True, config=config)
+    print(f'local rank {local_rank} / global rank {global_rank} successfully built train dataset')
+    dataset_val = build_dataset(is_train=False, config=config)
+    print(f'local rank {local_rank} / global rank {global_rank} successfully built val dataset')
+
+    sampler_train = torch.utils.data.DistributedSampler(dataset_train, shuffle=True)        
+    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+    print('train batch size: ', config.train.batch_size)
+    print('val batch size: ', config.val.batch_size)
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset_train,
+        sampler=sampler_train,
+        batch_size=config.train.batch_size,
+        num_workers=config.num_workers,
+        pin_memory=True,
+        drop_last=True,
+        persistent_workers=True,
+        collate_fn=collate_fn,  ### NOTE THIS ###
+        #shuffle=False,
+    )
+
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset_val,
+        sampler=sampler_val,
+        batch_size=config.val.batch_size,
+        num_workers=config.val.num_workers,
+        pin_memory=True,
+        drop_last=False,
+        persistent_workers=True,
+    )
+    return dataset_train, dataset_val, data_loader_train, data_loader_val
+
+def build_dataset(is_train, config):
+    img_transform = build_img_transform(is_train, config.img_aug, config.with_dc)
+    text_transform = build_text_transform(is_train, config.text_aug, config.with_dc)
+    split = 'train' if is_train else 'val'
+
+    image_reader = config[split].get('image_reader', {})
+    dataset = ClipDataset(
+        root_dir=config[split]['root_dir'],
+        meta_file=config[split]['meta_file'],
+        img_transform=img_transform,
+        text_transform=text_transform,
+        read_from=config[split]['read_from'],
+        evaluator=None, # no evaluator for now
+        image_reader_type=image_reader.get('type', 'pil'),
+        fseek=config[split].get('fseek',False),
+        split=split,
+        cross_image=config[split].get('cross_image', False),
+        mask_type=config[split].get('mask_type', 'class'),
+        use_distilbert=config[split].get('use_distilbert', True),
+        class_label_dir=config[split].get('class_label_dir', None),
+        sample_list_dir=config[split].get('sample_list_dir', None),
+    )
+    print('dataset len: ', len(dataset))
+    return dataset
+
+def build_img_transform(is_train, config, with_dc=True):
+    if not config.deit_aug:
+        if is_train:
+            transform = transforms.Compose([
+                transforms.RandomResizedCrop(config.img_size, scale=config.img_scale),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
+            ])
+        else:
+            transform = transforms.Compose([
+                transforms.Resize(config.img_size + 32),
+                transforms.CenterCrop(config.img_size),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
+            ])
+        return transform
+
+    if is_train:
+        # this should always dispatch to transforms_imagenet_train
+        transform = create_transform(
+            input_size=config.img_size,
+            is_training=True,
+            color_jitter=config.color_jitter if config.color_jitter > 0 else None,
+            auto_augment=config.auto_augment if config.auto_augment != 'none' else None,
+            re_prob=config.re_prob,
+            re_mode=config.re_mode,
+            re_count=config.re_count,
+            interpolation=config.interpolation,
+        )
+    else:
+        size = int((256 / 224) * config.img_size)
+        transform = transforms.Compose([
+            transforms.Resize(size, interpolation=_pil_interp(config.interpolation)),
+            transforms.CenterCrop(config.img_size),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
+        ])
+    if with_dc:
+        transform = transforms.Compose([*transform.transforms, ToDataContainer()])
+
+    return transform
+
+def build_text_transform(is_train, config, with_dc=True):
+    local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0
+    if is_train:
+        ### only on local rank 0 ###
+        if local_rank == 0:
+            ### either let nltk download the data itself, or pre-download it and point nltk to that directory ###
+            # nltk.download('popular')
+            nltk.data.path.append('/mnt/petrelfs/xujilan/nltk_data')
+            
+        transform = WordAugTokenizeWrapper(
+            Tokenize(SimpleTokenizer(), max_seq_len=config.max_seq_len),
+            max_word=config.multi_label,
+            word_type=config.word_type)
+    else:
+        transform = Tokenize(SimpleTokenizer(), max_seq_len=config.max_seq_len)
+            
+    if with_dc:
+        transform = transforms.Compose([transform, ToDataContainer()])
+    return transform
+    
+class Tokenize:
+
+    def __init__(self, tokenizer, max_seq_len=77, truncate=True):
+        self.tokenizer = tokenizer
+        self.max_seq_len = max_seq_len
+        self.truncate = truncate
+
+    def __call__(self, texts):
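+        # accepts a single string or a list of strings; returns a LongTensor of shape
+        # (max_seq_len,) or (N, max_seq_len) with <|startoftext|>/<|endoftext|> added and optional truncation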
+        expanded_dim = False
+        if isinstance(texts, str):
+            texts = [texts]
+            expanded_dim = True
+        
+        sot_token = self.tokenizer.encoder['<|startoftext|>']
+        eot_token = self.tokenizer.encoder['<|endoftext|>']
+        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
+        result = torch.zeros(len(all_tokens), self.max_seq_len, dtype=torch.long)
+
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > self.max_seq_len:
+                if self.truncate:
+                    tokens = tokens[:self.max_seq_len]
+                    tokens[-1] = eot_token
+                else:
+                    raise RuntimeError(f'Input {texts[i]} is too long for context length {self.max_seq_len}')
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        if expanded_dim:
+            return result[0]
+
+        return result
+
+
+class WordAugTokenizeWrapper:
+
+    def __init__(self, tokenize, max_word=3, template_set='full', word_type='noun'):
+        self.tokenize = tokenize
+        self.max_word = max_word
+        from .imagenet_template import (full_imagenet_templates, sub_imagenet_template, simple_imagenet_template,
+                                        identity_template)
+        assert template_set in ['full', 'subset', 'simple', 'identity']
+        if template_set == 'full':
+            templates = full_imagenet_templates
+        elif template_set == 'subset':
+            templates = sub_imagenet_template
+        elif template_set == 'simple':
+            templates = simple_imagenet_template
+        elif template_set == 'identity':
+            templates = identity_template
+        else:
+            raise ValueError
+        self.templates = templates
+        assert word_type in ['noun', 'noun_phrase']
+        self.word_type = word_type
+
+    def get_tag(self, tokenized, tags):
+        if not isinstance(tags, (list, tuple)):
+            tags = [tags]
+        ret = []
+        for (word, pos) in nltk.pos_tag(tokenized):
+            for tag in tags:
+                if pos == tag:
+                    ret.append(word)
+        return ret
+
+    def get_tag_with_loc(self, tokenized, tags):
+        if not isinstance(tags, (list, tuple)):
+            tags = [tags]
+        ret = []
+        loc = []
+        for i, (word, pos) in enumerate(nltk.pos_tag(tokenized)):
+            for tag in tags:
+                if pos == tag:
+                    ret.append(word)
+                    loc.append(i)
+        return ret, loc
+    
+    def get_noun_phrase(self, tokenized):
+        # Taken from Su Nam Kim Paper...
+        grammar = r"""
+            NBAR:
+                {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns
+
+            NP:
+                {<NBAR>}
+                {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
+        """
+        chunker = nltk.RegexpParser(grammar)
+
+        chunked = chunker.parse(nltk.pos_tag(tokenized))
+        continuous_chunk = []
+        current_chunk = []
+
+        for subtree in chunked:
+            if isinstance(subtree, nltk.Tree):
+                current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
+            elif current_chunk:
+                named_entity = ' '.join(current_chunk)
+                if named_entity not in continuous_chunk:
+                    continuous_chunk.append(named_entity)
+                    current_chunk = []
+            else:
+                continue
+
+        # flush the last chunk, otherwise a noun phrase at the end of the sentence would be dropped
+        if current_chunk:
+            named_entity = ' '.join(current_chunk)
+            if named_entity not in continuous_chunk:
+                continuous_chunk.append(named_entity)
+
+        return continuous_chunk
+
+    def __call__(self, text):
+        """
+        Args:
+            text: str
+        
+        """
+        assert isinstance(text, str)
+        tokenized = nltk.word_tokenize(text)
+        
+        nouns, locs = [], []  # initialise locs as well; it stays empty for the 'noun_phrase' branch
+        if len(tokenized) > 0:
+            if self.word_type == 'noun':
+                # nouns = self.get_tag(tokenized, ['NN', 'NNS', 'NNP', 'VBG', 'VB', 'VBD', 'VBN', 'VBP', 'VBZ'])
+                # nouns = self.get_tag(tokenized, ['NN', 'NNS'])
+                # nouns, locs = self.get_tag_with_loc(tokenized, ['NN', 'NNS'])
+                nouns, locs = self.get_tag_with_loc(tokenized, ['NN', 'NNS', 'NNP',])
+            elif self.word_type == 'noun_phrase':
+                nouns = self.get_noun_phrase(tokenized)
+            else:
+                raise ValueError('word_type must be noun or noun_phrase')
+        
+        ### By default, we use this ###
+        if self.max_word == 0:
+            return self.tokenize(text), nouns, locs, text
+        
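+        # e.g. with max_word=3 and nouns ['dog', 'ball'], sample up to 3 nouns, wrap each in a
+        # random template ('a photo of a dog.', ...), and pad with the original caption if needed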
+        prompt_texts = []
+        if len(nouns) > 0:
+            select_nouns = np.random.choice(nouns, min(self.max_word, len(nouns)), replace=False)
+            prompt_texts = [np.random.choice(self.templates).format(noun) for noun in select_nouns]
+        if len(prompt_texts) < self.max_word:
+            prompt_texts += [text] * (self.max_word - len(prompt_texts))
+
+        texts = [text] + prompt_texts
+        return self.tokenize(texts), nouns, locs, texts

+ 457 - 0
datasets/clip_dataset.py

@@ -0,0 +1,457 @@
+# -------------------------------------------------------------------------
+# Written by Jilan Xu
+# -------------------------------------------------------------------------
+
+from re import L
+import torch
+import json
+import os.path as osp
+import requests
+import numpy as np
+import time
+from typing import List
+from .base_dataset import BaseDataset
+# from prototype.data.image_reader import build_image_reader
+from .image_reader import build_image_reader
+# import linklink as link
+import random
+import os
+import omegaconf
+import clip
+from ipdb import set_trace
+from .tokenizer import SimpleTokenizer
+from .imagenet_template import full_imagenet_templates
+from nltk.stem import WordNetLemmatizer
+lemmatizer = WordNetLemmatizer()
+
+### frequently appeared 100 entities ###
+TOP_CLASSES_1=[
+    'people', 'man', 'men', 'woman', 'women', 'girl', 'boy', 'lady', 'kid', 'child', 'children', 'baby', 'student', 'bride', 'groom', 'couple', 'prince', 'princess', \
+    'car', 'bus', 'truck', 'motorcycle', 'train', 'bicycle', 'boat', 'aeroplane', 'airplane', 'motorbike', 'bike',\
+    'cup', 'bottle', 'bowl', 'knife', 'spoon',  'glass', 'fork',\
+    'chair', 'table', 'bench', 'clock', 'laptop', 'light', 'vase', 'plant', 'remote', 'microwave', 'toaster', 'oven','mouse', 'keyboard','sofa', 'monitor','desk', 'tv','TV', 'couch', 'flower','refrigerator', \
+    'house', 'building', 'hotel',\
+    'handbag', 'umbrella','book', 'backpack', 'phone', 'shirt', 'tie', 'suitcase','T-shirt', 'bag',  'box', \
+    'sink','bed','toilet',\
+    'cat','dog',  'horse', 'bird','cow', 'sheep' ,'elephant', 'bear', 'zebra', 'giraffe', \
+    'ball', 'racket', 'skateboard', 'skis', 'snowboard', 'surfboard', 'kite', \
+    'pizza', 'cake', 'apple', 'banana', 'sandwich', 'orange', 'carrot', 'donut' ,\
+]
+
+### some of the entities are similar, map them to a single one ###
+syn_dict = {
+    'people':'people', 'man':'people', 'men':'people', 'woman':'people', 'women':'people', 'girl':'people', 'boy':'people', 'lady':'people', 'kid':'people', 'child':'people', 'children':'people', 'baby':'people', 'student':'people', 'bride':'people', 'groom':'people', 'couple':'people', 'prince':'people', 'princess':'people',\
+    'airplane': 'aeroplane','motorbike': 'motorcycle','bike': 'bicycle',\
+    'TV':'tv', 'desk': 'table', 'couch':'sofa',\
+    'building': 'house', 'hotel': 'house', \
+    'T-shirt': 'shirt','T-Shirt': 'shirt', 'handbag': 'bag', \
+}
+
+### unique entities ###
+TOP_UNIQUE_CLASSES = [
+    'people', 'car', 'bus', 'truck', 'motorcycle', \
+    'train', 'bicycle', 'boat', 'aeroplane', 'cup', \
+    'bottle', 'bowl', 'knife', 'spoon',  'glass', \
+    'fork', 'chair', 'table', 'bench', 'clock', \
+    'laptop', 'light', 'vase', 'plant', 'remote',\
+    'microwave', 'toaster', 'oven','mouse', 'keyboard',\
+    'sofa', 'monitor', 'tv', 'flower','refrigerator', \
+    'house', 'bag', 'umbrella','book', 'backpack', \
+    'phone', 'shirt', 'tie', 'suitcase', 'box',\
+    'sink','bed','toilet', 'cat','dog', \
+    'horse', 'bird','cow', 'sheep' ,'elephant', \
+    'bear', 'zebra', 'giraffe',  'ball', 'racket', \
+    'skateboard', 'skis', 'snowboard', 'surfboard', 'kite',\
+    'pizza', 'cake', 'apple', 'banana', 'sandwich',\
+    'orange', 'carrot', 'donut' ,\
+]
+
+TOP_UNIQUE_CLASSES_IDX = {}
+for i, x in enumerate(TOP_UNIQUE_CLASSES):
+    TOP_UNIQUE_CLASSES_IDX[x] = i
+
+class ClipDataset(BaseDataset):
+    """
+    Clip Dataset.
+
+    Arguments:
+        - root_dir (:obj:`str` or list): root directory (or directories) of the dataset
+        - meta_file (:obj:`str` or list): path(s) to the meta file(s)
+        - img_transform / text_transform: image and text transforms
+        - read_from (:obj:`str`): how the raw data is read, e.g. 'mc' or 'petrel'
+        - evaluator (:obj:`Evaluator`): evaluator used to compute metrics
+        - image_reader_type (:obj:`str`): reader type, 'pil' or 'kestrel'
+        - fseek (:obj:`bool`): load meta lines lazily via file offsets instead of keeping them in memory
+        - cross_image (:obj:`bool`): additionally sample a cross image that shares the same entity class
+        - mask_type / use_distilbert / class_label_dir / sample_list_dir: entity-masking related options
+    Metafile example::
+        "{"filename": "n01440764/n01440764_10026.JPEG", "label": 0, "label_name": "dog"}\n"
+    """
+
+    def __init__(self, root_dir, meta_file, img_transform=None, text_transform=None,
+                 read_from='mc', evaluator=None, image_reader_type='pil',
+                 fseek=False, label_texts_ensemble='none', split='train',
+                 cross_image=False, use_entity=True, mask_type='class', use_distilbert=True, class_label_dir=None, sample_list_dir=None,
+                 ):
+        if not isinstance(meta_file, List) and not isinstance(meta_file, omegaconf.listconfig.ListConfig):
+            meta_file = [meta_file]
+        if not isinstance(root_dir, List) and not isinstance(root_dir, omegaconf.listconfig.ListConfig):
+            root_dir = [root_dir]
+
+        self.meta_file = meta_file
+        self.root_dir = root_dir
+        self.read_from = read_from
+        self.img_transform = img_transform
+        self.text_transform = text_transform
+        self.evaluator = evaluator
+        self.image_reader = build_image_reader(image_reader_type)
+
+        self.fseek = fseek
+        self.initialized = False
+        self.label_texts_ensemble = label_texts_ensemble
+        self.num = 0
+        self.split=split
+
+        self.cross_image = cross_image
+        self.use_entity = use_entity
+        self.tokenizer = SimpleTokenizer()
+        self.mask_type = mask_type
+        self.use_distilbert = use_distilbert        
+        if self.cross_image:
+            self._load_meta_class_dict(class_label_dir, sample_list_dir)
+
+        self.metas = []
+
+        ### fseek loads each meta line on the fly via file-pointer offsets ###
+        ### this saves memory at the cost of extra loading time ###
+        if self.fseek:
+            self.line_offsets = []
+            for each_meta_file in meta_file:
+                line_offset = []
+                offset = 0
+                with open(each_meta_file) as f:
+                    for line in f:
+                        line_offset.append(offset)
+                        offset += len(line.encode('UTF-8'))
+                    f.close()
+                self.num += len(line_offset)
+                self.line_offsets.append(line_offset)
+        else:
+            ### read from local file and load all metafile info ###
+            for rd, each_meta_file in zip(root_dir, meta_file):
+                with open(each_meta_file) as f:
+                    lines = f.readlines()
+                self.num += len(lines)
+
+                for line in lines:
+                    info = json.loads(line)
+                    filename = osp.join(rd, info['filename'])
+                    ### add root_dir to filename ###
+                    info['filename'] = filename
+                    self.metas.append(info)
+
+        super(ClipDataset, self).__init__(root_dir=root_dir,
+                                          meta_file=meta_file,
+                                          read_from=read_from,
+                                          transform=img_transform,
+                                          evaluator=evaluator)
+
+
+    def __len__(self):        
+        return self.num
+
+    def _str2list(self, x):
+        if type(x) is list:
+            return x
+        elif type(x) is str:
+            return [x]
+        else:
+            raise RuntimeError(
+                "unknown value for _str2list: {}".format(type(x)))
+
+    def _load_meta(self, idx):
+        if self.fseek:
+            source_id = 0
+            while idx >= len(self.line_offsets[source_id]):
+                idx -= len(self.line_offsets[source_id])
+                source_id += 1 #fixed
+            with open(self.meta_file[source_id]) as f:
+                f.seek(self.line_offsets[source_id][idx])
+                line = f.readline()
+                meta = json.loads(line)
+                filename = osp.join(self.root_dir[source_id], meta['filename'])
+                meta['filename'] = filename
+                f.close()
+            return meta
+        else:
+            return self.metas[idx]
+        
+    def _load_meta_class_dict(self, class_label_dir, sample_list_dir):
+        # load class dict which is used to sample cross_image
+        with open(sample_list_dir) as f:
+            lines = f.readline()
+            self.class_dict = json.loads(lines)
+
+        # load class label for each sample    
+        with open(class_label_dir) as f:
+            lines = f.readline()
+            self.class_label = json.loads(lines)
+                
+    def sample_cross_image(self, curr_cls):
+        class_list = self.class_dict[curr_cls]
+        filename, caption = random.choice(class_list)
+        # curr_meta = self._load_meta(idx)
+        # filename = curr_meta['filename']
+        filename = osp.join(self.root_dir[0], filename)
+        curr_meta = {'filename':filename, 'caption':caption}
+        img_bytes = self.read_file(curr_meta)
+        img = self.image_reader(img_bytes, filename)
+        caption = curr_meta['caption'] if 'caption' in curr_meta else ''
+        raw_caption = curr_meta['caption'] if 'caption' in curr_meta else ''
+        caption, nouns, locs, _ = self.text_transform(caption)
+        return img, caption, raw_caption
+
+
+    def __getitem__(self, idx):
+        curr_meta = self._load_meta(idx)
+        filename = curr_meta['filename']
+
+        label = int(curr_meta['label']) if 'label' in curr_meta else -1
+        label_name = curr_meta['label_name'] if 'label_name' in curr_meta else None
+        caption = curr_meta['caption'] if 'caption' in curr_meta else ''
+        
+        raw_caption = curr_meta['caption'] if 'caption' in curr_meta else ''
+        tag = self._str2list(curr_meta['tag']) if 'tag' in curr_meta else []
+        ret_info = {}
+
+        #############
+
+        try:
+            assert self.is_contains_chinese(caption) == False
+            img_bytes = self.read_file(curr_meta)
+        
+            img = self.image_reader(img_bytes, filename)
+            if self.img_transform is not None:
+                image = self.img_transform(img)
+                    
+            if self.text_transform is not None:
+                if self.split == 'train':
+                    ### for the clip TextTransformer, captions are tokenised here ###
+                    ### for bert/distilbert, the text transform only selects nouns; captions are tokenised later ###
+                    caption, nouns, locs, prompt_texts = self.text_transform(caption)
+                    
+                    if self.use_entity:
+                        if self.use_distilbert:
+                            ### bert/distilbert-like, questions/answers will be tokenised later ###
+                            raw_question, question, raw_answer, answer = self.build_question_and_answer_for_distilbert(raw_caption, nouns)
+                        else: 
+                            ### clip TextTransformer-like, questions/answers are tokenised ###
+                            raw_question, question, raw_answer, answer = self.build_question_and_answer(raw_caption, nouns)
+                
+                        ret_info['question'] = question
+                        ret_info['answer'] = answer
+                        ret_info['raw_question'] = raw_question
+                        ret_info['raw_answer'] = raw_answer
+
+
+                    if self.cross_image:
+                        imgname = filename.split('/')[-1]
+                        top100_label = self.class_label[imgname] # the label is str, due to some issues
+                        crossimg, crosscaption, cross_rawcaption = self.sample_cross_image(top100_label) 
+                        # crossimg = tensor_trans(trans(crossimg))
+                        crossimg = self.img_transform(crossimg)
+                        
+                        cross_entity = 'A photo of ' + TOP_UNIQUE_CLASSES[int(top100_label)]
+                        ret_info['cross_image'] = crossimg
+                        ret_info['cross_entity'] = cross_entity
+                else:
+                    caption = self.text_transform(caption)
+            
+            ret_info['image'] = image
+            ret_info['caption'] = caption
+            ret_info['target'] = label
+            ret_info['raw_caption'] = raw_caption
+            # ret_info['filename'] = filename
+            return ret_info    
+                        
+        except Exception as e:          
+            print(e)
+            # return self.__getitem__(0)
+    
+    # def judge_noun(self, n):
+    #     n = n.replace('.', '')
+    #     ans = n.split("'s")[0].split(',')[0]
+    #     ### conduct Lemmatization ###
+    #     # ans = nlp(ans)[0].lemma_
+        
+    #     if ans in syn_dict:
+    #         ans = syn_dict[ans]
+    #     elif len(ans) >= 2 and ans[-2:] == 'es' and ans[:-2] in syn_dict:
+    #         ans = syn_dict[ans[:-2]]    
+    #     elif len(ans) >= 1 and ans[-1] == 's' and ans[:-1] in syn_dict:
+    #         ans = syn_dict[ans[:-1]]
+    #     elif ans.lower() in syn_dict:
+    #         ans = syn_dict[ans.lower()]
+    #     elif len(ans) >= 2 and ans[-2:] == 'es' and ans.lower()[:-2] in syn_dict:
+    #         ans = syn_dict[ans.lower()[:-2]]
+    #     elif len(ans) >= 1 and ans[-1] == 's' and ans.lower()[:-1] in syn_dict:
+    #         ans = syn_dict[ans.lower()[:-1]]
+
+    #     if ans in TOP_UNIQUE_CLASSES:
+    #         return 1, ans
+    #     elif len(ans) >= 2 and ans[-2:] == 'es' and ans[:-2] in TOP_UNIQUE_CLASSES:
+    #         return 1, ans[:-2]
+    #     elif len(ans) >= 1 and ans[-1] == 's' and ans[:-1] in TOP_UNIQUE_CLASSES:
+    #         return 1, ans[:-1]
+    #     elif ans.lower() in TOP_UNIQUE_CLASSES:
+    #         return 1, ans.lower()
+    #     elif len(ans) >= 2 and ans.lower()[-2:] == 'es' and ans.lower()[:-2] in TOP_UNIQUE_CLASSES:
+    #         return 1, ans.lower()[:-2]
+    #     elif len(ans) >= 1 and ans.lower()[-1] == 's' and ans.lower()[:-1] in TOP_UNIQUE_CLASSES:
+    #         return 1, ans.lower()[:-1]
+    #     return 0, n
+    
+    def judge_noun(self, n):
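+        # e.g. 'dogs' lemmatizes to 'dog', which is a top-100 class -> (1, 'dog');
+        # 'desk' maps through syn_dict to 'table' -> (1, 'table'); anything else returns (0, word)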
+        n = n.replace('.', '')
+        # ans = n.split("'s")[0].split(',')[0]
+        # ans = n.strip("'s").strip(",")
+        ans = n
+        ### conduct Lemmatization ###
+        # ans = nlp(ans.lower())[0].lemma_
+        ans = lemmatizer.lemmatize(ans.lower())
+        
+        if ans in syn_dict:
+            ans = syn_dict[ans]
+        
+        if ans in TOP_UNIQUE_CLASSES:
+            return 1, ans
+        return 0, n       
+    
+    def build_question_and_answer(self, caption, nouns):
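+        # keep the caption text as the question but remap every recognised entity token to id 332 ('M'),
+        # and build the answer as '<random template prefix> <entity1> and <entity2> ...'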
+        words = caption.split(' ')
+        question = ''
+        ans_list = []
+
+        token_mapper = {}
+        word_mapper = {}
+        assert self.mask_type == 'class'
+        for word in words:
+            word_after = word
+            word_flag, newword = self.judge_noun(word)
+            if word_flag == 1:
+                question = question + newword + ' '
+                ans_list.append(newword)
+                token_id = self.tokenizer.encode(newword)[0]
+                token_mapper[token_id] = TOP_UNIQUE_CLASSES_IDX[newword]
+                word_mapper[token_id] = 332   ### this is 'M'
+            else:
+                question = question + word + ' '
+                    
+        question = question.replace("'", '').strip()
+        raw_question = question
+        
+        question, _, _, _ = self.text_transform(raw_question)
+        question = torch.tensor([word_mapper[int(word)] if int(word) in word_mapper else word for word in question])
+        # raw_answer = 'A photo of ' + ' and '.join(list(set(ans_list))) ## unique words
+        raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(ans_list)))
+        answer, _, _, _ = self.text_transform(raw_answer)
+        
+        return raw_question, question, raw_answer, answer
+
+
+    def build_question_and_answer_for_distilbert(self, caption, nouns):
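+        # e.g. 'a dog chases a cat' -> raw_question 'a [MASK] chases a [MASK]' and
+        # raw_answer '<random template prefix> dog and cat'; tokenisation is left to the bert/distilbert tokenizer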
+        words = caption.split(' ')
+        question = ''
+        entity_list = []
+
+        ### default, mask all entites ###
+        assert self.mask_type == 'class'
+        for word in words:
+            word_after = word
+            word_flag, newword = self.judge_noun(word)
+            if word_flag == 1:
+                question = question + '[MASK]' + ' '
+                entity_list.append(newword)
+            else:
+                question = question + word + ' '
+    
+        question = question.replace("'", '').strip()
+        raw_question = question
+        #### build and transform answers ###
+        # raw_answer = 'A photo of ' + ' and '.join(list(set(ans_list))) ## unique words
+        raw_answer = random.choice(full_imagenet_templates).split('{}')[0] + ' and '.join(list(set(entity_list)))    
+        return raw_question, None, raw_answer, None
+
+    def is_contains_chinese(self, strs):
+        for _char in strs:
+            if '\u4e00' <= _char <= '\u9fa5':
+                return True
+        return False
+
+    def _get_label_text(self, text):
+        # label_text = ['a photo of ' + text + '.']
+        if self.label_texts_ensemble == 'prompt6':
+            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt6'
+        elif self.label_texts_ensemble == 'prompt8':
+            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt8'
+        elif self.label_texts_ensemble == 'prompt80':
+            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt80'
+        elif self.label_texts_ensemble == 'cc':
+            return [text]
+        else:
+            f = f'{osp.abspath(os.getcwd())}/../../prototype/data/datasets/prompts/query_pattern_prompt1'
+        label_text = []
+        with open(f) as fin:
+            for line in fin.readlines():
+                label_text.append(line.replace('{0}', text))
+        return label_text
+
+    def get_label_texts(self,):
+        label_to_name = {}
+        for curr_meta in self.metas:
+            label = int(curr_meta['label']) if 'label' in curr_meta else None
+            label_name = curr_meta['label_name'] if 'label_name' in curr_meta else None
+            if label is not None and label_name is not None:
+                label_to_name[label] = label_name
+        labels = list(label_to_name.keys())
+        labels.sort()
+
+        label_texts = []
+        label_text_len = []
+        for label in labels:
+            label_name = label_to_name[label]
+            label_text = self._get_label_text(label_name)
+            label_texts.extend(label_text)
+            label_text_len.append(len(label_text))
+
+        all_len = sum(label_text_len)
+        offset = 0
+        label_num = len(labels)
+        label_texts_ensemble_matrix = torch.zeros(all_len, label_num)
+        for lbl, ltl in enumerate(label_text_len):
+            label_texts_ensemble_matrix[offset: offset + ltl, lbl] = 1
+            offset += ltl
+
+        return label_texts, label_texts_ensemble_matrix
+
+    def dump(self, writer, output):
+        filenames = output['filenames']
+        image_ids = output['image_ids']
+        label_names = output['label_names']
+        captions = output['captions']
+        tags = output['tags']
+        prediction = self.tensor2numpy(output['prediction'])
+        score = self.tensor2numpy(output['score'])
+        labels = self.tensor2numpy(output['labels'])
+        for _idx in range(len(filenames)):
+            res = {
+                'image_id': int(image_ids[_idx]),
+                'filename': filenames[_idx],
+                'label': int(labels[_idx]),
+                'label_name': label_names[_idx],
+                'caption': captions[_idx],
+                'tag': tags[_idx],
+                'prediction': int(prediction[_idx]),
+                'score': [float('%.8f' % s) for s in score[_idx]]
+            }
+            writer.write(json.dumps(res, ensure_ascii=False) + '\n')
+        writer.flush()

+ 36 - 0
datasets/formatting.py

@@ -0,0 +1,36 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import torch
+from mmcv.parallel import DataContainer as DC
+
+
+class ToDataContainer(object):
+    """Convert results to :obj:`mmcv.DataContainer`"""
+
+    def __call__(self, sample):
+        """Call function to convert data in results to
+        :obj:`mmcv.DataContainer`.
+
+        Args:
+            sample (torch.Tensor): Input sample.
+
+        Returns:
+            DataContainer
+        """
+        if isinstance(sample, int):
+            sample = torch.tensor(sample)
+        return DC(sample, stack=True, pad_dims=None)
+
+    def __repr__(self):
+        return self.__class__.__name__

+ 46 - 0
datasets/image_reader.py

@@ -0,0 +1,46 @@
+# -------------------------------------------------------------------------
+# Written by Jilan Xu
+# -------------------------------------------------------------------------
+
+import io
+from PIL import Image
+import logging
+# kestrel is an internal library; make it optional so the default PIL reader still works without it
+try:
+    import kestrel as ks
+except ImportError:
+    ks = None
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+logger = logging.getLogger('global')
+
+
+
+def pil_loader(img_bytes, filepath):
+    buff = io.BytesIO(img_bytes)
+    try:
+        with Image.open(buff) as img:
+            img = img.convert('RGB')
+    except IOError:
+        logger.info('Failed in loading {}'.format(filepath))
+        raise  # re-raise instead of returning an undefined image
+    return img
+
+
+def kestrel_loader(img_bytes, filepath):
+    input_frame = ks.Frame()
+    try:
+        image_data = img_bytes.tobytes()
+        input_frame.create_from_mem(image_data, len(image_data))
+        if input_frame.frame_type != ks.KESTREL_VIDEO_RGB:
+            input_frame = input_frame.cvt_color(ks.KESTREL_VIDEO_RGB)
+        if ks.Device().mem_type() == ks.KESTREL_MEM_DEVICE:
+            input_frame = input_frame.upload()
+    except IOError:
+        logger.info('Failed in loading {}'.format(filepath))
+    return [input_frame]
+
+
+def build_image_reader(reader_type):
+    if reader_type == 'pil':
+        return pil_loader
+    elif reader_type == 'kestrel':
+        return kestrel_loader
+    else:
+        raise NotImplementedError

+ 267 - 0
datasets/imagenet_template.py

@@ -0,0 +1,267 @@
+# -------------------------------------------------------------------------
+# MIT License
+#
+# Copyright (c) 2021 OpenAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+full_imagenet_templates = [
+    'a bad photo of a {}.',
+    'a photo of many {}.',
+    'a sculpture of a {}.',
+    'a photo of the hard to see {}.',
+    'a low resolution photo of the {}.',
+    'a rendering of a {}.',
+    'graffiti of a {}.',
+    'a bad photo of the {}.',
+    'a cropped photo of the {}.',
+    'a tattoo of a {}.',
+    'the embroidered {}.',
+    'a photo of a hard to see {}.',
+    'a bright photo of a {}.',
+    'a photo of a clean {}.',
+    'a photo of a dirty {}.',
+    'a dark photo of the {}.',
+    'a drawing of a {}.',
+    'a photo of my {}.',
+    'the plastic {}.',
+    'a photo of the cool {}.',
+    'a close-up photo of a {}.',
+    'a black and white photo of the {}.',
+    'a painting of the {}.',
+    'a painting of a {}.',
+    'a pixelated photo of the {}.',
+    'a sculpture of the {}.',
+    'a bright photo of the {}.',
+    'a cropped photo of a {}.',
+    'a plastic {}.',
+    'a photo of the dirty {}.',
+    'a jpeg corrupted photo of a {}.',
+    'a blurry photo of the {}.',
+    'a photo of the {}.',
+    'a good photo of the {}.',
+    'a rendering of the {}.',
+    'a {} in a video game.',
+    'a photo of one {}.',
+    'a doodle of a {}.',
+    'a close-up photo of the {}.',
+    'a photo of a {}.',
+    'the origami {}.',
+    'the {} in a video game.',
+    'a sketch of a {}.',
+    'a doodle of the {}.',
+    'a origami {}.',
+    'a low resolution photo of a {}.',
+    'the toy {}.',
+    'a rendition of the {}.',
+    'a photo of the clean {}.',
+    'a photo of a large {}.',
+    'a rendition of a {}.',
+    'a photo of a nice {}.',
+    'a photo of a weird {}.',
+    'a blurry photo of a {}.',
+    'a cartoon {}.',
+    'art of a {}.',
+    'a sketch of the {}.',
+    'a embroidered {}.',
+    'a pixelated photo of a {}.',
+    'itap of the {}.',
+    'a jpeg corrupted photo of the {}.',
+    'a good photo of a {}.',
+    'a plushie {}.',
+    'a photo of the nice {}.',
+    'a photo of the small {}.',
+    'a photo of the weird {}.',
+    'the cartoon {}.',
+    'art of the {}.',
+    'a drawing of the {}.',
+    'a photo of the large {}.',
+    'a black and white photo of a {}.',
+    'the plushie {}.',
+    'a dark photo of a {}.',
+    'itap of a {}.',
+    'graffiti of the {}.',
+    'a toy {}.',
+    'itap of my {}.',
+    'a photo of a cool {}.',
+    'a photo of a small {}.',
+    'a tattoo of the {}.',
+]
+
+sub_imagenet_template = [
+    'itap of a {}.', 'a bad photo of a {}.', 'a origami {}.', 'a photo of the large {}.', 'a {} in a video game.',
+    'art of the {}.', 'a photo of the small {}.'
+]
+
+simple_imagenet_template = [
+    'a photo of a {}.',
+]
+
+identity_template = [
+    '{}',
+]
+
+template_meta = {
+    'full': full_imagenet_templates,
+    'subset': sub_imagenet_template,
+    'simple': simple_imagenet_template,
+    'identity': identity_template,
+}
+
+imagenet_classes = [
+    'tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead shark', 'electric ray', 'stingray', 'rooster',
+    'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', 'indigo bunting', 'American robin', 'bulbul',
+    'jay', 'magpie', 'chickadee', 'American dipper', 'kite (bird of prey)', 'bald eagle', 'vulture', 'great grey owl',
+    'fire salamander', 'smooth newt', 'newt', 'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog',
+    'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle', 'mud turtle', 'terrapin', 'box turtle',
+    'banded gecko', 'green iguana', 'Carolina anole', 'desert grassland whiptail lizard', 'agama',
+    'frilled-necked lizard', 'alligator lizard', 'Gila monster', 'European green lizard', 'chameleon', 'Komodo dragon',
+    'Nile crocodile', 'American alligator', 'triceratops', 'worm snake', 'ring-necked snake', 'eastern hog-nosed snake',
+    'smooth green snake', 'kingsnake', 'garter snake', 'water snake', 'vine snake', 'night snake', 'boa constrictor',
+    'African rock python', 'Indian cobra', 'green mamba', 'sea snake', 'Saharan horned viper',
+    'eastern diamondback rattlesnake', 'sidewinder rattlesnake', 'trilobite', 'harvestman', 'scorpion',
+    'yellow garden spider', 'barn spider', 'European garden spider', 'southern black widow', 'tarantula', 'wolf spider',
+    'tick', 'centipede', 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peafowl', 'quail',
+    'partridge', 'african grey parrot', 'macaw', 'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater',
+    'hornbill', 'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser', 'goose', 'black swan', 'tusker',
+    'echidna', 'platypus', 'wallaby', 'koala', 'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm',
+    'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton', 'chambered nautilus', 'Dungeness crab', 'rock crab',
+    'fiddler crab', 'red king crab', 'American lobster', 'spiny lobster', 'crayfish', 'hermit crab', 'isopod',
+    'white stork', 'black stork', 'spoonbill', 'flamingo', 'little blue heron', 'great egret', 'bittern bird',
+    'crane bird', 'limpkin', 'common gallinule', 'American coot', 'bustard', 'ruddy turnstone', 'dunlin',
+    'common redshank', 'dowitcher', 'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale',
+    'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin', 'Maltese', 'Pekingese', 'Shih Tzu',
+    'King Charles Spaniel', 'Papillon', 'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound', 'Beagle',
+    'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound', 'Treeing Walker Coonhound', 'English foxhound',
+    'Redbone Coonhound', 'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet', 'Ibizan Hound',
+    'Norwegian Elkhound', 'Otterhound', 'Saluki', 'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier',
+    'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', 'Kerry Blue Terrier', 'Irish Terrier',
+    'Norfolk Terrier', 'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier', 'Lakeland Terrier',
+    'Sealyham Terrier', 'Airedale Terrier', 'Cairn Terrier', 'Australian Terrier', 'Dandie Dinmont Terrier',
+    'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer', 'Standard Schnauzer', 'Scottish Terrier',
+    'Tibetan Terrier', 'Australian Silky Terrier', 'Soft-coated Wheaten Terrier', 'West Highland White Terrier',
+    'Lhasa Apso', 'Flat-Coated Retriever', 'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever',
+    'Chesapeake Bay Retriever', 'German Shorthaired Pointer', 'Vizsla', 'English Setter', 'Irish Setter',
+    'Gordon Setter', 'Brittany dog', 'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel',
+    'Cocker Spaniel', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz', 'Schipperke', 'Groenendael dog', 'Malinois',
+    'Briard', 'Australian Kelpie', 'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie', 'Border Collie',
+    'Bouvier des Flandres dog', 'Rottweiler', 'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher',
+    'Greater Swiss Mountain Dog', 'Bernese Mountain Dog', 'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer',
+    'Bullmastiff', 'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. Bernard', 'husky', 'Alaskan Malamute',
+    'Siberian Husky', 'Dalmatian', 'Affenpinscher', 'Basenji', 'pug', 'Leonberger', 'Newfoundland dog',
+    'Great Pyrenees dog', 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'brussels griffon', 'Pembroke Welsh Corgi',
+    'Cardigan Welsh Corgi', 'Toy Poodle', 'Miniature Poodle', 'Standard Poodle',
+    'Mexican hairless dog (xoloitzcuintli)', 'grey wolf', 'Alaskan tundra wolf', 'red wolf or maned wolf', 'coyote',
+    'dingo', 'dhole', 'African wild dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox', 'grey fox', 'tabby cat',
+    'tiger cat', 'Persian cat', 'Siamese cat', 'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar',
+    'lion', 'tiger', 'cheetah', 'brown bear', 'American black bear', 'polar bear', 'sloth bear', 'mongoose', 'meerkat',
+    'tiger beetle', 'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle', 'dung beetle', 'rhinoceros beetle',
+    'weevil', 'fly', 'bee', 'ant', 'grasshopper', 'cricket insect', 'stick insect', 'cockroach', 'praying mantis',
+    'cicada', 'leafhopper', 'lacewing', 'dragonfly', 'damselfly', 'red admiral butterfly', 'ringlet butterfly',
+    'monarch butterfly', 'small white butterfly', 'sulphur butterfly', 'gossamer-winged butterfly', 'starfish',
+    'sea urchin', 'sea cucumber', 'cottontail rabbit', 'hare', 'Angora rabbit', 'hamster', 'porcupine', 'fox squirrel',
+    'marmot', 'beaver', 'guinea pig', 'common sorrel horse', 'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus',
+    'ox', 'water buffalo', 'bison', 'ram (adult male sheep)', 'bighorn sheep', 'Alpine ibex', 'hartebeest',
+    'impala (antelope)', 'gazelle', 'arabian camel', 'llama', 'weasel', 'mink', 'European polecat',
+    'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', 'three-toed sloth', 'orangutan', 'gorilla',
+    'chimpanzee', 'gibbon', 'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur',
+    'black-and-white colobus', 'proboscis monkey', 'marmoset', 'white-headed capuchin', 'howler monkey', 'titi monkey',
+    "Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur', 'indri', 'Asian elephant',
+    'African bush elephant', 'red panda', 'giant panda', 'snoek fish', 'eel', 'silver salmon', 'rock beauty fish',
+    'clownfish', 'sturgeon', 'gar fish', 'lionfish', 'pufferfish', 'abacus', 'abaya', 'academic gown', 'accordion',
+    'acoustic guitar', 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', 'amphibious vehicle',
+    'analog clock', 'apiary', 'apron', 'trash can', 'assault rifle', 'backpack', 'bakery', 'balance beam', 'balloon',
+    'ballpoint pen', 'Band-Aid', 'banjo', 'baluster / handrail', 'barbell', 'barber chair', 'barbershop', 'barn',
+    'barometer', 'barrel', 'wheelbarrow', 'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap', 'bath towel',
+    'bathtub', 'station wagon', 'lighthouse', 'beaker', 'military hat (bearskin or shako)', 'beer bottle', 'beer glass',
+    'bell tower', 'baby bib', 'tandem bicycle', 'bikini', 'ring binder', 'binoculars', 'birdhouse', 'boathouse',
+    'bobsleigh', 'bolo tie', 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'hunting bow', 'bow tie',
+    'brass memorial plaque', 'bra', 'breakwater', 'breastplate', 'broom', 'bucket', 'buckle', 'bulletproof vest',
+    'high-speed train', 'butcher shop', 'taxicab', 'cauldron', 'candle', 'cannon', 'canoe', 'can opener', 'cardigan',
+    'car mirror', 'carousel', 'tool kit', 'cardboard box / carton', 'car wheel', 'automated teller machine', 'cassette',
+    'cassette player', 'castle', 'catamaran', 'CD player', 'cello', 'mobile phone', 'chain', 'chain-link fence',
+    'chain mail', 'chainsaw', 'storage chest', 'chiffonier', 'bell or wind chime', 'china cabinet',
+    'Christmas stocking', 'church', 'movie theater', 'cleaver', 'cliff dwelling', 'cloak', 'clogs', 'cocktail shaker',
+    'coffee mug', 'coffeemaker', 'spiral or coil', 'combination lock', 'computer keyboard', 'candy store',
+    'container ship', 'convertible', 'corkscrew', 'cornet', 'cowboy boot', 'cowboy hat', 'cradle', 'construction crane',
+    'crash helmet', 'crate', 'infant bed', 'Crock Pot', 'croquet ball', 'crutch', 'cuirass', 'dam', 'desk',
+    'desktop computer', 'rotary dial telephone', 'diaper', 'digital clock', 'digital watch', 'dining table',
+    'dishcloth', 'dishwasher', 'disc brake', 'dock', 'dog sled', 'dome', 'doormat', 'drilling rig', 'drum', 'drumstick',
+    'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', 'electric locomotive', 'entertainment center',
+    'envelope', 'espresso machine', 'face powder', 'feather boa', 'filing cabinet', 'fireboat', 'fire truck',
+    'fire screen', 'flagpole', 'flute', 'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen',
+    'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat', 'garbage truck',
+    'gas mask or respirator', 'gas pump', 'goblet', 'go-kart', 'golf ball', 'golf cart', 'gondola', 'gong', 'gown',
+    'grand piano', 'greenhouse', 'radiator grille', 'grocery store', 'guillotine', 'hair clip', 'hair spray',
+    'half-track', 'hammer', 'hamper', 'hair dryer', 'hand-held computer', 'handkerchief', 'hard disk drive',
+    'harmonica', 'harp', 'combine harvester', 'hatchet', 'holster', 'home theater', 'honeycomb', 'hook', 'hoop skirt',
+    'gymnastic horizontal bar', 'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron', 'carved pumpkin', 'jeans',
+    'jeep', 'T-shirt', 'jigsaw puzzle', 'rickshaw', 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle',
+    'lampshade', 'laptop computer', 'lawn mower', 'lens cap', 'letter opener', 'library', 'lifeboat', 'lighter',
+    'limousine', 'ocean liner', 'lipstick', 'slip-on shoe', 'lotion', 'music speaker', 'loupe magnifying glass',
+    'sawmill', 'magnetic compass', 'messenger bag', 'mailbox', 'tights', 'one-piece bathing suit', 'manhole cover',
+    'maraca', 'marimba', 'mask', 'matchstick', 'maypole', 'maze', 'measuring cup', 'medicine cabinet', 'megalith',
+    'microphone', 'microwave oven', 'military uniform', 'milk can', 'minibus', 'miniskirt', 'minivan', 'missile',
+    'mitten', 'mixing bowl', 'mobile home', 'ford model t', 'modem', 'monastery', 'monitor', 'moped',
+    'mortar and pestle', 'graduation cap', 'mosque', 'mosquito net', 'vespa', 'mountain bike', 'tent', 'computer mouse',
+    'mousetrap', 'moving van', 'muzzle', 'metal nail', 'neck brace', 'necklace', 'baby pacifier', 'notebook computer',
+    'obelisk', 'oboe', 'ocarina', 'odometer', 'oil filter', 'pipe organ', 'oscilloscope', 'overskirt', 'bullock cart',
+    'oxygen mask', 'product packet / packaging', 'paddle', 'paddle wheel', 'padlock', 'paintbrush', 'pajamas', 'palace',
+    'pan flute', 'paper towel', 'parachute', 'parallel bars', 'park bench', 'parking meter', 'railroad car', 'patio',
+    'payphone', 'pedestal', 'pencil case', 'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'plectrum',
+    'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank', 'pill bottle', 'pillow', 'ping-pong ball',
+    'pinwheel', 'pirate ship', 'drink pitcher', 'block plane', 'planetarium', 'plastic bag', 'plate rack', 'farm plow',
+    'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho', 'pool table', 'soda bottle', 'plant pot',
+    "potter's wheel", 'power drill', 'prayer rug', 'printer', 'prison', 'missile', 'projector', 'hockey puck',
+    'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket', 'radiator', 'radio', 'radio telescope',
+    'rain barrel', 'recreational vehicle', 'fishing casting reel', 'reflex camera', 'refrigerator', 'remote control',
+    'restaurant', 'revolver', 'rifle', 'rocking chair', 'rotisserie', 'eraser', 'rugby ball', 'ruler measuring stick',
+    'sneaker', 'safe', 'safety pin', 'salt shaker', 'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale',
+    'school bus', 'schooner', 'scoreboard', 'CRT monitor', 'screw', 'screwdriver', 'seat belt', 'sewing machine',
+    'shield', 'shoe store', 'shoji screen / room divider', 'shopping basket', 'shopping cart', 'shovel', 'shower cap',
+    'shower curtain', 'ski', 'balaclava ski mask', 'sleeping bag', 'slide rule', 'sliding door', 'slot machine',
+    'snorkel', 'snowmobile', 'snowplow', 'soap dispenser', 'soccer ball', 'sock', 'solar thermal collector', 'sombrero',
+    'soup bowl', 'keyboard space bar', 'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web', 'spindle',
+    'sports car', 'spotlight', 'stage', 'steam locomotive', 'through arch bridge', 'steel drum', 'stethoscope', 'scarf',
+    'stone wall', 'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa', 'submarine', 'suit',
+    'sundial', 'sunglasses', 'sunglasses', 'sunscreen', 'suspension bridge', 'mop', 'sweatshirt',
+    'swim trunks / shorts', 'swing', 'electrical switch', 'syringe', 'table lamp', 'tank', 'tape player', 'teapot',
+    'teddy bear', 'television', 'tennis ball', 'thatched roof', 'front curtain', 'thimble', 'threshing machine',
+    'throne', 'tile roof', 'toaster', 'tobacco shop', 'toilet seat', 'torch', 'totem pole', 'tow truck', 'toy store',
+    'tractor', 'semi-trailer truck', 'tray', 'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch',
+    'trolleybus', 'trombone', 'hot tub', 'turnstile', 'typewriter keyboard', 'umbrella', 'unicycle', 'upright piano',
+    'vacuum cleaner', 'vase', 'vaulted or arched ceiling', 'velvet fabric', 'vending machine', 'vestment', 'viaduct',
+    'violin', 'volleyball', 'waffle iron', 'wall clock', 'wallet', 'wardrobe', 'military aircraft', 'sink',
+    'washing machine', 'water bottle', 'water jug', 'water tower', 'whiskey jug', 'whistle', 'hair wig',
+    'window screen', 'window shade', 'Windsor tie', 'wine bottle', 'airplane wing', 'wok', 'wooden spoon', 'wool',
+    'split-rail fence', 'shipwreck', 'sailboat', 'yurt', 'website', 'comic book', 'crossword', 'traffic or street sign',
+    'traffic light', 'dust jacket', 'menu', 'plate', 'guacamole', 'consomme', 'hot pot', 'trifle', 'ice cream',
+    'popsicle', 'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog', 'mashed potatoes', 'cabbage', 'broccoli',
+    'cauliflower', 'zucchini', 'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber', 'artichoke',
+    'bell pepper', 'cardoon', 'mushroom', 'Granny Smith apple', 'strawberry', 'orange', 'lemon', 'fig', 'pineapple',
+    'banana', 'jackfruit', 'cherimoya (custard apple)', 'pomegranate', 'hay', 'carbonara', 'chocolate syrup', 'dough',
+    'meatloaf', 'pizza', 'pot pie', 'burrito', 'red wine', 'espresso', 'tea cup', 'eggnog', 'mountain', 'bubble',
+    'cliff', 'coral reef', 'geyser', 'lakeshore', 'promontory', 'sandbar', 'beach', 'valley', 'volcano',
+    'baseball player', 'bridegroom', 'scuba diver', 'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn',
+    'rose hip', 'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra', 'stinkhorn mushroom', 'earth star fungus',
+    'hen of the woods mushroom', 'bolete', 'corn cob', 'toilet paper'
+]
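Usage note (not part of the diff): the class names above are consumed by the zero-shot classifier in main_pretrain.py, which formats them with prompt templates (see config.evaluate.cls.template). A minimal sketch with an illustrative two-template subset, not the real template list:

# Minimal sketch: turning the class names above into zero-shot prompts.
# The two templates here are illustrative; the actual template set lives elsewhere in this file / the config.
templates = ['a photo of a {}.', 'a photo of the {}.']
prompts = [t.format(name) for name in imagenet_classes for t in templates]
assert len(prompts) == len(templates) * len(imagenet_classes)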

+ 247 - 0
datasets/sampler.py

@@ -0,0 +1,247 @@
+# -------------------------------------------------------------------------
+# Written by Jilan Xu
+# -------------------------------------------------------------------------
+
+import torch
+from torch.utils.data.sampler import Sampler
+import linklink as link  # internal distributed-communication backend; provides get_rank()/get_world_size() used below
+import math
+import numpy as np
+
+
+class DistributedSampler(Sampler):
+    def __init__(self, dataset, world_size=None, rank=None, round_up=True):
+        if world_size is None:
+            world_size = link.get_world_size()
+        if rank is None:
+            rank = link.get_rank()
+        self.dataset = dataset
+        self.world_size = world_size
+        self.rank = rank
+        self.round_up = round_up
+        self.epoch = 0
+
+        self.num_samples = int(
+            math.ceil(len(self.dataset) * 1.0 / self.world_size))
+        if self.round_up:
+            self.total_size = self.num_samples * self.world_size
+            self.length = self.num_samples
+        else:
+            self.total_size = len(self.dataset)
+
+        if self.rank < self.world_size-1:
+            self.length = self.num_samples
+        else:
+            self.length = self.total_size - \
+                (self.world_size-1)*self.num_samples
+
+    def __iter__(self):
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        indices = list(torch.randperm(len(self.dataset), generator=g))
+
+        if self.round_up:
+            indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        offset = self.num_samples * self.rank
+        indices = indices[offset:offset + self.num_samples]
+        if self.round_up or (not self.round_up and self.rank < self.world_size-1):
+            assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.length
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+
+class DistributedGivenIterationSampler(Sampler):
+    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=0):
+        if world_size is None:
+            world_size = link.get_world_size()
+        if rank is None:
+            rank = link.get_rank()
+        assert rank < world_size
+        self.dataset = dataset
+        self.total_iter = total_iter
+        self.batch_size = batch_size
+        self.world_size = world_size
+        self.rank = rank
+        self.last_iter = last_iter
+
+        self.total_size = self.total_iter*self.batch_size
+
+        self.indices = self.gen_new_list()
+        self.call = 0
+
+    def __iter__(self):
+        if self.call == 0:
+            self.call = 1
+            return iter(self.indices[self.last_iter*self.batch_size:])
+        else:
+            raise RuntimeError(
+                "this sampler is not designed to be called more than once!!")
+
+    def gen_new_list(self):
+        np.random.seed(0)
+        all_size = self.total_size * self.world_size
+        indices = np.arange(len(self.dataset))
+        indices = indices[:all_size]
+        num_repeat = (all_size-1) // indices.shape[0] + 1
+        indices = np.tile(indices, num_repeat)
+        indices = indices[:all_size]
+
+        np.random.shuffle(indices)
+        beg = self.total_size * self.rank
+        indices = indices[beg:beg+self.total_size]
+
+        assert len(indices) == self.total_size
+
+        return indices
+
+    def __len__(self):
+        # Note: last_iter is deliberately ignored here; __len__ is only used for
+        # progress display, and the dataloader handles the correct remaining size.
+        return self.total_size
+
+
+class DistributedEpochSampler(Sampler):
+    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=0):
+        if world_size is None:
+            world_size = link.get_world_size()
+        if rank is None:
+            rank = link.get_rank()
+        assert rank < world_size
+        self.dataset = dataset
+        self.total_iter = total_iter
+        self.batch_size = batch_size
+        self.world_size = world_size
+        self.rank = rank
+        self.last_iter = last_iter
+
+        self.all_size_single = self.total_iter * self.batch_size
+
+        self.indices = self.gen_new_list()
+        self.call = 0
+
+    def __iter__(self):
+        if self.call == 0:
+            self.call = 1
+            return iter(self.indices[self.last_iter*self.batch_size:])
+        else:
+            raise RuntimeError(
+                "this sampler is not designed to be called more than once!!")
+
+    def get_one_epoch_self_part(self):
+        num = len(self.dataset)
+        indices = np.arange(num)
+        extra_indices = np.random.choice(
+            num, self.extra_per_epoch, replace=False)
+        indices = np.concatenate((indices, extra_indices))
+        np.random.shuffle(indices)
+        assert len(indices) % (self.world_size * self.batch_size) == 0
+        num_single = len(indices) // self.world_size
+        return indices[self.rank*num_single:(self.rank+1)*num_single]
+
+    def gen_new_list(self):
+        np.random.seed(0)
+
+        self.all_num = self.total_iter * self.batch_size * self.world_size
+        iter_per_epoch = (len(self.dataset) -
+                          1) // (self.batch_size * self.world_size) + 1
+        self.num_per_epoch = iter_per_epoch * self.batch_size * self.world_size
+        self.extra_per_epoch = self.num_per_epoch - len(self.dataset)
+        repeat = (self.all_num - 1) // self.num_per_epoch + 1
+        indices = []
+        for i in range(repeat):
+            indice = self.get_one_epoch_self_part()
+            indices.append(indice)
+
+        indices = np.concatenate(indices)
+        indices = indices[:self.all_size_single]
+
+        assert len(indices) == self.all_size_single
+
+        return indices
+
+    def __len__(self):
+        return self.all_size_single
+
+class RankedGivenIterationSampler(Sampler):
+    def __init__(self, dataset, total_iter, batch_size, last_iter=0):
+
+        self.dataset = dataset
+        self.total_iter = total_iter
+        self.batch_size = batch_size
+        self.last_iter = last_iter
+
+        self.total_size = self.total_iter*self.batch_size
+        self.cur_size = self.last_iter * self.batch_size
+        # self.indices = self.gen_new_list()
+        self.indices = np.arange(len(self.dataset))
+        self.call = 0
+
+    def indice_generator(self):
+        np.random.shuffle(self.indices)
+        while self.cur_size < self.total_size:
+            #np.random.shuffle(self.indices)
+            remaining_size = self.total_size - self.cur_size
+            indices = self.indices[:remaining_size]
+            self.cur_size += len(indices)
+            for item in indices:
+                yield item
+            
+    def __iter__(self):
+        if self.call == 0:
+            self.call = 1
+            return self.indice_generator()
+        else:
+            raise RuntimeError("this sampler is not designed to be called more than once!!")
+
+    def __len__(self):
+        # Note: last_iter is deliberately ignored here; __len__ is only used for
+        # progress display, and the dataloader handles the correct remaining size.
+        return self.total_size
+
+sampler_dict = {
+    'distributed': DistributedSampler,
+    'distributed_iteration': DistributedGivenIterationSampler,
+    'distributed_epoch': DistributedEpochSampler,
+    'ranked_iteration': RankedGivenIterationSampler
+}
+
+
+def build_sampler(dataset, cfg_sampler, cfg_dataset):
+    batch_size = cfg_dataset['batch_size']
+    # check step type: iteration or epoch ?
+    if not getattr(cfg_dataset, 'max_iter', False):
+        world_size = link.get_world_size()
+        iter_per_epoch = (len(dataset) - 1) // (batch_size * world_size) + 1
+        if cfg_sampler['type'] == "naive":
+            total_iter = cfg_dataset['max_epoch'] * ((len(dataset) - 1) // batch_size + 1)  #125200
+        else:
+            total_iter = cfg_dataset['max_epoch'] * iter_per_epoch
+
+    else:
+        total_iter = cfg_dataset['max_iter']
+    # initialize sampler kwargs
+    if cfg_sampler['type'] in ['distributed', "naive", "random"]:
+        sampler_kwargs = {'dataset': dataset}
+    else:
+        sampler_kwargs = {
+            'dataset': dataset,
+            'batch_size': batch_size,
+            'total_iter': total_iter,
+            'last_iter': cfg_dataset['last_iter']
+        }
+    cfg_sampler['kwargs'].update(sampler_kwargs)
+    cfg_dataset['max_iter'] = total_iter
+    cfg_dataset.pop('dataset')
+
+    return sampler_dict[cfg_sampler['type']](**cfg_sampler['kwargs'])
+
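Usage note (not part of the diff): a minimal sketch of how the DistributedSampler above is typically driven, assuming the linklink process group is already initialized and that `dataset` and `num_epochs` are defined by the caller:

from torch.utils.data import DataLoader

sampler = DistributedSampler(dataset, round_up=True)   # rank/world_size are taken from linklink
loader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=4, pin_memory=True)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)   # re-seeds the permutation so each epoch shuffles differently
    for batch in loader:
        pass  # training step goes here

In this repo, build_sampler (above) is the intended entry point, selecting the sampler class via cfg_sampler['type'] rather than instantiating it directly.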

+ 170 - 0
datasets/tokenizer.py

@@ -0,0 +1,170 @@
+# -------------------------------------------------------------------------
+# MIT License
+#
+# Copyright (c) 2021 OpenAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bpe_simple_vocab_16e6.txt.gz')
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """Returns list of utf-8 byte and a corresponding list of unicode strings.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent
+    coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables
+    between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord('!'), ord('~') + 1)) + list(range(ord('¡'), ord('¬') + 1)) + list(range(ord('®'), ord('ÿ') + 1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
+        merges = merges[1:49152 - 256 - 2 + 1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + '</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+        self.pat = re.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>', )
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + '</w>'
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:  # noqa: E722
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace').replace('</w>', ' ')
+        return text

BIN
figs/model.png


+ 628 - 0
main_pretrain.py

@@ -0,0 +1,628 @@
+# -------------------------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE
+#
+# Written by Ze Liu, Zhenda Xie
+# Modified by Jiarui Xu
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+import argparse
+import datetime
+import os
+import os.path as osp
+import sys
+import time
+from collections import defaultdict
+import subprocess
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from datasets import build_loader, build_text_transform, imagenet_classes
+from mmcv.parallel import MMDistributedDataParallel
+from mmcv.runner import get_dist_info, init_dist, set_random_seed
+from mmcv.utils import collect_env, get_git_hash
+from mmseg.apis import multi_gpu_test
+from models import build_model
+from omegaconf import OmegaConf, read_write
+from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
+from timm.utils import AverageMeter, accuracy
+from utils import (auto_resume_helper, build_dataset_class_tokens, build_optimizer, build_scheduler, data2cuda,
+                   get_config, get_grad_norm, get_logger, load_checkpoint, parse_losses, reduce_tensor, save_checkpoint, momentum_update,
+                   load_checkpoint_stage1, build_dataset_class_lists,cdist_,
+                   )
+
+from ipdb import set_trace
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+from transformers import AutoTokenizer, RobertaTokenizer
+from einops import rearrange
+tokenizer_dict = {
+    # note: TOKENIZERS_PARALLELISM is normally set via the environment variable of the same name
+    'Bert': AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False),
+    'TextTransformer': None,
+}
+
+
+try:
+    # noinspection PyUnresolvedReferences
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+def parse_args():
+    parser = argparse.ArgumentParser('GroupViT training and evaluation script')
+    parser.add_argument('--cfg', type=str, required=True, help='path to config file')
+    parser.add_argument('--opts', help="Modify config options by adding 'KEY=VALUE' list. ", default=None, nargs='+')
+
+    # easy config modification
+    parser.add_argument('--batch-size', type=int, help='batch size for single GPU')
+    parser.add_argument('--resume', help='resume from checkpoint')
+    parser.add_argument(
+        '--amp-opt-level',
+        type=str,
+        default='O1',
+        choices=['O0', 'O1', 'O2'],
+        help='mixed precision opt level, if O0, no amp is used')
+    parser.add_argument(
+        '--output', type=str, help='root of output folder, '
+        'the full path is <output>/<model_name>/<tag>')
+    parser.add_argument('--tag', type=str, help='tag of experiment')
+    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+    parser.add_argument('--wandb', action='store_true', help='Use W&B to log experiments')
+    parser.add_argument('--keep', type=int, help='Maximum checkpoint to keep')
+
+    # distributed training
+    parser.add_argument('--local_rank', type=int, required=False, default=0, help='local rank for DistributedDataParallel')
+    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+    args = parser.parse_args()
+
+    return args
+
+
+def train(cfg):
+    if cfg.wandb and dist.get_rank() == 0:
+        import wandb
+        wandb.init(
+            project='group_vit',
+            name=osp.join(cfg.model_name, cfg.tag),
+            dir=cfg.output,
+            config=OmegaConf.to_container(cfg, resolve=True),
+            resume=cfg.checkpoint.auto_resume)
+    else:
+        wandb = None
+    # waiting wandb init
+    dist.barrier()
+    
+    dataset_train, dataset_val, \
+        data_loader_train, data_loader_val = build_loader(cfg.data)
+    print('Done train/val loader')
+    data_loader_seg = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))
+    print('Done seg loader')
+    
+    logger = get_logger()
+    if dist.get_rank() == 0:
+        writer = SummaryWriter(cfg.output)
+    else:
+        writer = None
+
+    logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
+    model = build_model(cfg.model)
+    model.cuda()
+    logger.info(str(model))
+
+    optimizer = build_optimizer(cfg.train, model)
+    if cfg.train.amp_opt_level != 'O0':
+        model, optimizer = amp.initialize(model, optimizer, opt_level=cfg.train.amp_opt_level)
+
+    model = MMDistributedDataParallel(model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=True)
+    model_without_ddp = model.module
+        
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f'number of params: {n_parameters}')
+    lr_scheduler = build_scheduler(cfg.train, optimizer, len(data_loader_train))
+
+    ##### load init params from stage 1 here, before auto resuming ######
+    if cfg.checkpoint.stage1_checkpoint:
+        load_checkpoint_stage1(cfg, model_without_ddp)
+
+    if cfg.checkpoint.auto_resume:
+        resume_file = auto_resume_helper(cfg.output)
+        if resume_file:
+            if cfg.checkpoint.resume:
+                logger.warning(f'auto-resume changing resume file from {cfg.checkpoint.resume} to {resume_file}')
+            with read_write(cfg):
+                cfg.checkpoint.resume = resume_file
+            logger.info(f'auto resuming from {resume_file}')
+        else:
+            logger.info(f'no checkpoint found in {cfg.output}, ignoring auto resume')
+
+    max_accuracy = max_miou = 0.0
+    max_metrics = {'max_accuracy': max_accuracy, 'max_miou': max_miou}
+
+    if cfg.checkpoint.resume:
+        max_metrics = load_checkpoint(cfg, model_without_ddp, optimizer, lr_scheduler)
+        max_accuracy, max_miou = max_metrics['max_accuracy'], max_metrics['max_miou']
+    
+    ############# set tokenizer ##############
+    global tokenizer
+    tokenizer = tokenizer_dict[cfg.model.text_encoder.type]
+    tensorbd_logdir = cfg.output + "/logs"
+
+    logger.info('Start training')
+    start_time = time.time()
+    
+    for epoch in range(cfg.train.start_epoch, cfg.train.epochs):
+        ### train model ###
+        loss_train_dict = train_one_epoch(cfg, model, data_loader_train, optimizer, epoch, lr_scheduler, writer)
+        if dist.get_rank() == 0 and (epoch % cfg.checkpoint.save_freq == 0 or epoch == (cfg.train.epochs - 1)):
+            save_checkpoint(cfg, epoch, model_without_ddp, {
+                'max_accuracy': max_accuracy,
+                'max_miou': max_miou
+            }, optimizer, lr_scheduler)
+        dist.barrier()
+        loss_train = loss_train_dict['total_loss']
+        logger.info(f'Avg loss of the network on the {len(dataset_train)} train images: {loss_train:.2f}')
+                
+        # evaluate
+        if (epoch % cfg.evaluate.eval_freq == 0 or epoch == (cfg.train.epochs - 1)):
+            if 'cls' in cfg.evaluate.task:
+                acc1, acc5, loss = validate_cls(cfg, data_loader_val, model)
+                logger.info(f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%')
+                max_metrics['max_accuracy'] = max(max_metrics['max_accuracy'], acc1)
+                # if cfg.evaluate.cls.save_best and dist.get_rank() == 0 and acc1 > max_accuracy:
+                #     save_checkpoint(
+                #         cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_acc1')
+                dist.barrier()
+                max_accuracy = max_metrics['max_accuracy']
+                logger.info(f'Max accuracy: {max_accuracy:.2f}%')
+
+            if 'seg' in cfg.evaluate.task:
+                miou = validate_seg(cfg, data_loader_seg, model, epoch, writer, tokenizer=tokenizer)
+                logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
+                max_metrics['max_miou'] = max(max_metrics['max_miou'], miou)
+                if cfg.evaluate.seg.save_best and dist.get_rank() == 0 and miou > max_miou:
+                    print('Saving the best-mIoU checkpoint')
+                    save_checkpoint(
+                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_miou')
+                dist.barrier()
+                max_miou = max_metrics['max_miou']
+                logger.info(f'Max mIoU: {max_miou:.2f}%')
+
+        if wandb is not None:
+            log_stat = {f'epoch/train_{k}': v for k, v in loss_train_dict.items()}
+            log_stat.update({
+                'epoch/val_acc1': acc1,
+                'epoch/val_acc5': acc5,
+                'epoch/val_loss': loss,
+                'epoch/val_miou': miou,
+                'epoch/epoch': epoch,
+                'epoch/n_parameters': n_parameters
+            })
+            wandb.log(log_stat)
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    logger.info('Training time {}'.format(total_time_str))
+    dist.barrier()
+    # writer.flush()
+
+def process_text(text_data):
+    ### all experiments were run with padding=True, i.e. padding to the longest caption in the batch ###
+    # text_data = tokenizer(text_data, return_tensors='pt', padding=True,
+    #                         truncation=True, max_length=77)
+
+    ### padding to a fixed max_length is more memory-friendly and balances load better across GPUs ###
+    text_data = tokenizer(text_data, return_tensors='pt', padding='max_length',
+                            truncation=True, max_length=77)
+    text_data = {key: val.cuda() for key, val in text_data.items()}
+    return text_data
+
+                    
+def generate_entity_masks(text_data):
+    text = text_data['input_ids']
+    # [b, L]
+    entity_masks = text.clone()
+    entity_masks[entity_masks != 103] = 0
+    entity_masks[entity_masks == 103] = 1
+    
+    entity_masks  = entity_masks.to(text.device)
+    return entity_masks
+
+def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, writer):
+    logger = get_logger()
+    dist.barrier()
+    model.train()
+    optimizer.zero_grad()
+    if config.wandb and dist.get_rank() == 0:
+        import wandb
+    else:
+        wandb = None
+
+    num_steps = len(data_loader)
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    norm_meter = AverageMeter()
+    log_vars_meters = defaultdict(AverageMeter)
+
+    start = time.time()
+    end = time.time()
+
+    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
+    
+    for idx, samples in enumerate(data_loader):        
+        batch_size = config.data.train.batch_size
+        all_images = samples['image'].cuda()
+
+        all_questions = None
+        entity_labels = entity_masks = None
+        all_answers = None
+        if config.model.text_encoder['type'] in ['DistilBert','Bert','BertMedium','Roberta']:
+            all_texts = process_text(samples['raw_caption'])
+            if config.data.train.use_entity is True:
+                all_questions = process_text(samples['raw_question'])
+                all_answers= process_text(samples['raw_answer'])
+                entity_masks = generate_entity_masks(all_questions)
+
+        elif config.model.text_encoder['type'] not in ['TextTransformer'] and config.data.train.use_entity is True:
+            all_texts = samples['caption'].cuda()
+            all_questions = samples['question'].cuda()
+            all_answers = samples['answer'].cuda()
+        else:
+            all_texts = samples['caption'].cuda()
+        
+        ### for cross-image mask consistency loss ###
+        all_crossimage = samples['cross_image'].cuda() if 'cross_image' in samples and samples['cross_image'] is not None else None
+        question_masks = samples['question_mask'].cuda() if 'question_mask' in samples else None
+        cross_entity = process_text(samples['cross_entity']) if 'cross_entity' in samples and samples['cross_entity'] is not None else None
+
+        ### forward and compute loss ###
+        losses = model(image=all_images, text=all_texts, cross_image=all_crossimage, cross_entity=cross_entity, \
+                        question=all_questions, answer=all_answers, entity_masks=entity_masks, question_masks=question_masks)
+        loss, log_vars = parse_losses(losses)
+        
+        if dist.get_rank() == 0:
+            writer.add_scalar("Total loss", loss, len(data_loader) * epoch + idx)
+            writer.add_scalar("contrastive loss", losses['loss'], len(data_loader) * epoch + idx)
+            if 'entity' in losses:
+                writer.add_scalar("entity loss", losses['entity'], len(data_loader) * epoch + idx)
+            if 'mask' in losses:
+                writer.add_scalar("Mask loss", losses['mask'], len(data_loader) * epoch + idx)
+            writer.add_scalar("lr",  optimizer.param_groups[0]['lr'], len(data_loader) * epoch + idx)
+
+        if config.train.accumulation_steps > 1:
+            loss = loss / config.train.accumulation_steps
+            if config.train.amp_opt_level != 'O0':
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+                if config.train.clip_grad:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
+                else:
+                    grad_norm = get_grad_norm(amp.master_params(optimizer))
+            else:
+                loss.backward()
+                if config.train.clip_grad:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
+                else:
+                    grad_norm = get_grad_norm(model.parameters())
+            if (idx + 1) % config.train.accumulation_steps == 0:
+                optimizer.step()
+                optimizer.zero_grad()
+                lr_scheduler.step_update(epoch * num_steps + idx)
+            
+            if config.model.use_maskloss:
+                maskloss_coeff = 0.99
+                momentum_update(model.module.img_encoder, model.module.img_encoder_momentum, maskloss_coeff)
+        
+        else:
+            optimizer.zero_grad()
+            if config.train.amp_opt_level != 'O0':
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+                if config.train.clip_grad:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
+                else:
+                    grad_norm = get_grad_norm(amp.master_params(optimizer))
+            else:
+                loss.backward()
+                if config.train.clip_grad:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
+                else:
+                    grad_norm = get_grad_norm(model.parameters())
+            optimizer.step()
+            lr_scheduler.step_update(epoch * num_steps + idx)
+            if config.model.use_maskloss:
+                maskloss_coeff = 0.99
+                momentum_update(model.module.img_encoder, model.module.img_encoder_momentum, maskloss_coeff)
+
+
+        torch.cuda.synchronize()
+        loss_meter.update(loss.item(), batch_size)
+        for loss_name in log_vars:
+            log_vars_meters[loss_name].update(log_vars[loss_name], batch_size)
+        norm_meter.update(grad_norm)
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.print_freq == 0:
+            lr = optimizer.param_groups[0]['lr']
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            etas = batch_time.avg * (num_steps - idx)
+            log_vars_str = '\t'.join(f'{n} {m.val:.4f} ({m.avg:.4f})' for n, m in log_vars_meters.items())
+            logger.info(f'Train: [{epoch}/{config.train.epochs}][{idx}/{num_steps}]\t'
+                        f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
+                        f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
+                        f'total_loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                        f'{log_vars_str}\t'
+                        f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
+                        f'mem {memory_used:.0f}MB')
+            if wandb is not None:
+                log_stat = {f'iter/train_{n}': m.avg for n, m in log_vars_meters.items()}
+                log_stat['iter/train_total_loss'] = loss_meter.avg
+                log_stat['iter/learning_rate'] = lr
+                wandb.log(log_stat)
+
+    epoch_time = time.time() - start
+    logger.info(f'EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}')
+    result_dict = dict(total_loss=loss_meter.avg)
+    for n, m in log_vars_meters.items():
+        result_dict[n] = m.avg
+    dist.barrier()
+    return result_dict
+
+@torch.no_grad()
+def validate_cls(config, data_loader, model):
+    logger = get_logger()
+    dist.barrier()
+    criterion = torch.nn.CrossEntropyLoss()
+    model.eval()
+
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    acc1_meter = AverageMeter()
+    acc5_meter = AverageMeter()
+    
+    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
+
+    end = time.time()
+    logger.info('Building zero shot classifier')
+
+    if config.model.text_encoder['type'] in ['DistilBert', 'Bert','BertMedium','Roberta']:
+        text_embedding = model.module.build_text_embedding(
+                build_dataset_class_lists(config.evaluate.cls.template, imagenet_classes), tokenizer, len(imagenet_classes))
+    else:    
+        text_embedding = data2cuda(
+            model.module.build_text_embedding(
+                build_dataset_class_tokens(text_transform, config.evaluate.cls.template, imagenet_classes)))
+        
+    logger.info('Zero shot classifier built')
+    
+    for idx, samples in enumerate(data_loader):
+        all_images = samples['image'].cuda()
+        target = samples['target'].cuda()
+        output = model(image=all_images, text=text_embedding)
+        # measure accuracy and record loss
+        loss = criterion(output, target)
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        acc1 = reduce_tensor(acc1)
+        acc5 = reduce_tensor(acc5)
+        loss = reduce_tensor(loss)
+
+        loss_meter.update(loss.item(), target.size(0))
+        acc1_meter.update(acc1.item(), target.size(0))
+        acc5_meter.update(acc5.item(), target.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.print_freq == 0:
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
+                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
+                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
+                        f'Mem {memory_used:.0f}MB')
+    logger.info('Clearing zero shot classifier')
+    torch.cuda.empty_cache()
+    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
+    dist.barrier()
+    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
+
+@torch.no_grad()
+def validate_seg(config, data_loader, model, epoch=0, writer=None, tokenizer=None):
+    logger = get_logger()
+    dist.barrier()
+    model.eval()
+
+    if hasattr(model, 'module'):
+        model_without_ddp = model.module
+    else:
+        model_without_ddp = model
+
+    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
+    if config.model.text_encoder['type'] in ['DistilBert', 'Bert','BertMedium','Roberta']:
+        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg, tokenizer)
+    else:
+        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)
+
+    mmddp_model = MMDistributedDataParallel(
+        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
+    mmddp_model.eval()
+    results = multi_gpu_test(
+        model=mmddp_model,
+        data_loader=data_loader,
+        tmpdir=None,
+        gpu_collect=True,
+        efficient_test=False,
+        pre_eval=True,
+        format_only=False)
+
+    if dist.get_rank() == 0:
+        metric = [data_loader.dataset.evaluate(results, metric='mIoU')]
+    else:
+        metric = [None]
+    dist.broadcast_object_list(metric)
+    miou_result = metric[0]['mIoU'] * 100
+
+    torch.cuda.empty_cache()
+    logger.info(f'Eval Seg mIoU {miou_result:.2f}')
+    if writer is not None and dist.get_rank() == 0:
+        writer.add_scalar("mIoU", miou_result, epoch)
+    dist.barrier()
+    return miou_result
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in the master process.
+    """
+    import builtins as __builtin__
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+def init_distributed_mode(args):
+    # launched with torch.distributed.launch
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    # launched with submitit on a slurm cluster
+    elif 'SLURM_PROCID' in os.environ:
+        proc_id = int(os.environ['SLURM_PROCID'])
+        ntasks = os.environ['SLURM_NTASKS']
+        node_list = os.environ['SLURM_NODELIST']
+        num_gpus = torch.cuda.device_count()
+        addr = subprocess.getoutput(
+            'scontrol show hostname {} | head -n1'.format(node_list)
+        )
+        master_port = os.environ.get('MASTER_PORT', '29484')
+        os.environ['MASTER_PORT'] = master_port
+        os.environ['MASTER_ADDR'] = addr
+        os.environ['WORLD_SIZE'] = str(ntasks)
+        os.environ['RANK'] = str(proc_id)
+        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+        os.environ['LOCAL_SIZE'] = str(num_gpus)
+        args.dist_url = 'env://'
+        args.world_size = int(ntasks)
+        args.rank = int(proc_id)
+        args.gpu = int(proc_id % num_gpus)
+        print(f'SLURM MODE: proc_id: {proc_id}, ntasks: {ntasks}, node_list: {node_list}, num_gpus:{num_gpus}, addr:{addr}, master port:{master_port}' )
+        
+    # launched naively with `python main_pretrain.py`
+    # we manually add MASTER_ADDR and MASTER_PORT to env variables
+    elif torch.cuda.is_available():
+        print('Will run the code on one GPU.')
+        args.rank, args.gpu, args.world_size = 0, 0, 1
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29500'
+    else:
+        print('Does not support training without GPU.')
+        sys.exit(1)
+
+    dist.init_process_group(
+        backend="nccl",
+        init_method=args.dist_url,
+        world_size=args.world_size,
+        rank=args.rank,
+    )
+
+    torch.cuda.set_device(args.gpu)
+    print('| distributed init (rank {}): {}'.format(
+        args.rank, args.dist_url), flush=True)
+    dist.barrier()
+    setup_for_distributed(args.rank == 0)
+
+def main():
+    args = parse_args()
+    cfg = get_config(args)
+
+    if cfg.train.amp_opt_level != 'O0':
+        assert amp is not None, 'amp not installed!'
+    '''
+    # start faster ref: https://github.com/open-mmlab/mmdetection/pull/7036
+    mp.set_start_method('fork', force=True)
+    init_dist('pytorch')
+    rank, world_size = get_dist_info()
+    print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')
+    dist.barrier()
+    '''
+    init_distributed_mode(args)
+    rank, world_size = args.rank, args.world_size
+
+    set_random_seed(cfg.seed, use_rank_shift=True)
+    cudnn.benchmark = True
+
+    os.makedirs(cfg.output, exist_ok=True)
+    logger = get_logger(cfg)
+
+    # linear scale the learning rate according to total batch size, may not be optimal
+    linear_scaled_lr = cfg.train.base_lr * cfg.data.train.batch_size * world_size / 4096.0
+    linear_scaled_warmup_lr = cfg.train.warmup_lr * cfg.data.train.batch_size * world_size / 4096.0
+    linear_scaled_min_lr = cfg.train.min_lr * cfg.data.train.batch_size * world_size / 4096.0
+
+
+    # gradient accumulation also need to scale the learning rate
+    if cfg.train.accumulation_steps > 1:
+        linear_scaled_lr = linear_scaled_lr * cfg.train.accumulation_steps
+        linear_scaled_warmup_lr = linear_scaled_warmup_lr * cfg.train.accumulation_steps
+        linear_scaled_min_lr = linear_scaled_min_lr * cfg.train.accumulation_steps
+
+    with read_write(cfg):
+        logger.info(f'Scale base_lr from {cfg.train.base_lr} to {linear_scaled_lr}')
+        logger.info(f'Scale warmup_lr from {cfg.train.warmup_lr} to {linear_scaled_warmup_lr}')
+        logger.info(f'Scale min_lr from {cfg.train.min_lr} to {linear_scaled_min_lr}')
+        cfg.train.base_lr = linear_scaled_lr
+        cfg.train.warmup_lr = linear_scaled_warmup_lr
+        cfg.train.min_lr = linear_scaled_min_lr
+
+    if dist.get_rank() == 0:
+        path = os.path.join(cfg.output, 'config.json')
+        OmegaConf.save(cfg, path)
+        logger.info(f'Full config saved to {path}')
+
+    # log env info
+    env_info_dict = collect_env()
+    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
+    dash_line = '-' * 60 + '\n'
+    logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
+
+    logger.info(f'Git hash: {get_git_hash(digits=7)}')
+
+    # print config
+    logger.info(OmegaConf.to_yaml(cfg))
+
+    train(cfg)
+    dist.barrier()
+
+if __name__ == '__main__':
+    main()
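Implementation note (not part of the diff): the hard-coded id 103 in generate_entity_masks is the [MASK] token id of the BERT/DistilBERT uncased vocabulary. A small sanity-check sketch, assuming transformers and the distilbert-base-uncased tokenizer are available:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('distilbert-base-uncased')
batch = tok(['a photo of a [MASK] on the grass'],
            return_tensors='pt', padding='max_length', truncation=True, max_length=77)

assert tok.mask_token_id == 103
# equivalent to the clone-and-compare done in generate_entity_masks
entity_masks = (batch['input_ids'] == tok.mask_token_id).long()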

+ 273 - 0
main_seg.py

@@ -0,0 +1,273 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# ------------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+import argparse
+import os
+import os.path as osp
+import subprocess
+import sys
+
+import mmcv
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from datasets import build_text_transform
+from main_pretrain import validate_seg
+from mmcv.image import tensor2imgs
+from mmcv.parallel import MMDistributedDataParallel
+from mmcv.runner import set_random_seed
+from models import build_model
+from omegaconf import OmegaConf, read_write
+
+from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
+from utils import get_config, get_logger, load_checkpoint
+from transformers import AutoTokenizer, RobertaTokenizer
+from ipdb import set_trace
+
+try:
+    # noinspection PyUnresolvedReferences
+    from apex import amp
+except ImportError:
+    amp = None
+    
+tokenizer_dict = {
+    'Bert': AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False),
+    # hard-coded local RoBERTa checkpoint; point this at your own copy (or 'roberta-base')
+    'Roberta': RobertaTokenizer.from_pretrained('/mnt/petrelfs/xujilan/roberta-base/'),
+    'TextTransformer': None,
+}
+
+def parse_args():
+    parser = argparse.ArgumentParser('GroupViT segmentation evaluation and visualization')
+    parser.add_argument(
+        '--cfg',
+        type=str,
+        required=True,
+        help='path to config file',
+    )
+    parser.add_argument(
+        '--opts',
+        help="Modify config options by adding 'KEY VALUE' pairs. ",
+        default=None,
+        nargs='+',
+    )
+
+    parser.add_argument('--resume', help='resume from checkpoint')
+    parser.add_argument(
+        '--output', type=str, help='root of output folder, '
+        'the full path is <output>/<model_name>/<tag>')
+    parser.add_argument('--tag', help='tag of experiment')
+    parser.add_argument(
+        '--vis',
+        help='Specify the visualization mode, '
+        'could be a list, support input, pred, input_seg, input_pred_seg_label, all_groups, first_group, last_group',
+        default=None,
+        nargs='+')
+
+    # distributed training
+    parser.add_argument('--local_rank', type=int, required=False, default=0, help='local rank for DistributedDataParallel')
+
+    args = parser.parse_args()
+
+    return args
+
+
+def inference(cfg):
+    logger = get_logger()
+    data_loader = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))
+    dataset = data_loader.dataset
+    print('Visualization modes: ', cfg.vis)
+    
+    logger.info(f'Evaluating dataset: {dataset}')
+
+    logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
+    model = build_model(cfg.model)
+    model.cuda()
+    logger.info(str(model))
+
+    if cfg.train.amp_opt_level != 'O0':
+        model = amp.initialize(model, None, opt_level=cfg.train.amp_opt_level)
+
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f'number of params: {n_parameters}')
+
+    load_checkpoint(cfg, model, None, None)
+    
+    global tokenizer
+    tokenizer = tokenizer_dict[cfg.model.text_encoder.type]
+    if cfg.model.text_encoder.type == 'Roberta':
+        tokenizer = RobertaTokenizer.from_pretrained('/mnt/petrelfs/xujilan/roberta-base/')
+        print('Switched to the RoBERTa tokenizer')
+    
+    if 'seg' in cfg.evaluate.task:
+        miou = validate_seg(cfg, data_loader, model, tokenizer=tokenizer)
+        logger.info(f'mIoU of the network on the {len(data_loader.dataset)} test images: {miou:.2f}%')
+    else:
+        logger.info('No segmentation evaluation specified')
+
+    if cfg.vis:
+        vis_seg(cfg, data_loader, model, cfg.vis)
+
+
+@torch.no_grad()
+def vis_seg(config, data_loader, model, vis_modes):
+    dist.barrier()
+    model.eval()
+    
+    if hasattr(model, 'module'):
+        model_without_ddp = model.module
+    else:
+        model_without_ddp = model
+
+    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
+    if config.model.text_encoder['type'] in ['DistilBert', 'Bert','BertMedium','Roberta']:
+        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg, tokenizer)
+    else:
+        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)
+
+    mmddp_model = MMDistributedDataParallel(
+        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
+    mmddp_model.eval()
+    model = mmddp_model.module
+    device = next(model.parameters()).device
+    dataset = data_loader.dataset
+
+    if dist.get_rank() == 0:
+        prog_bar = mmcv.ProgressBar(len(dataset))
+    loader_indices = data_loader.batch_sampler
+    for batch_indices, data in zip(loader_indices, data_loader):
+        with torch.no_grad():
+            result = mmddp_model(return_loss=False, **data)
+
+        img_tensor = data['img'][0]
+        img_metas = data['img_metas'][0].data[0]
+        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+        assert len(imgs) == len(img_metas)
+
+        for batch_idx, img, img_meta in zip(batch_indices, imgs, img_metas):
+            h, w, _ = img_meta['img_shape']
+            img_show = img[:h, :w, :]
+
+            ori_h, ori_w = img_meta['ori_shape'][:-1]
+            img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+            for vis_mode in vis_modes:
+                out_file = osp.join(config.output, 'vis_imgs', vis_mode, f'{batch_idx:04d}.jpg')
+                model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)
+            if dist.get_rank() == 0:
+                batch_size = len(result) * dist.get_world_size()
+                for _ in range(batch_size):
+                    prog_bar.update()    
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in the master process.
+    """
+    import builtins as __builtin__
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+def init_distributed_mode(args):
+    # launched with torch.distributed.launch
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    # launched with submitit on a slurm cluster
+    elif 'SLURM_PROCID' in os.environ:
+        #args.rank = int(os.environ['SLURM_PROCID'])
+        #args.gpu = args.rank % torch.cuda.device_count()
+        proc_id = int(os.environ['SLURM_PROCID'])
+        ntasks = os.environ['SLURM_NTASKS']
+        node_list = os.environ['SLURM_NODELIST']
+        num_gpus = torch.cuda.device_count()
+        addr = subprocess.getoutput(
+            'scontrol show hostname {} | head -n1'.format(node_list)
+        )
+        master_port = os.environ.get('MASTER_PORT', '29499')
+        os.environ['MASTER_PORT'] = master_port
+        os.environ['MASTER_ADDR'] = addr
+        os.environ['WORLD_SIZE'] = str(ntasks)
+        os.environ['RANK'] = str(proc_id)
+        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+        os.environ['LOCAL_SIZE'] = str(num_gpus)
+        args.dist_url = 'env://'
+        args.world_size = int(ntasks)
+        args.rank = int(proc_id)
+        args.gpu = int(proc_id % num_gpus)
+        print(f'SLURM MODE: proc_id: {proc_id}, ntasks: {ntasks}, node_list: {node_list}, num_gpus: {num_gpus}, addr: {addr}, master_port: {master_port}')
+        
+    # launched naively with `python main_seg.py`
+    # we manually add MASTER_ADDR and MASTER_PORT to the environment variables
+    elif torch.cuda.is_available():
+        print('Will run the code on one GPU.')
+        args.rank, args.gpu, args.world_size = 0, 0, 1
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29500'
+    else:
+        print('Running without a GPU is not supported.')
+        sys.exit(1)
+
+    dist.init_process_group(
+        backend="nccl",
+        init_method=args.dist_url,
+        world_size=args.world_size,
+        rank=args.rank,
+    )
+
+    torch.cuda.set_device(args.gpu)
+    print('| distributed init (rank {}): {}'.format(
+        args.rank, args.dist_url), flush=True)
+    dist.barrier()
+    setup_for_distributed(args.rank == 0)
+
+def main():
+    args = parse_args()
+    cfg = get_config(args)
+
+    if cfg.train.amp_opt_level != 'O0':
+        assert amp is not None, 'amp not installed!'
+
+    with read_write(cfg):
+        cfg.evaluate.eval_only = True
+    
+    init_distributed_mode(args)
+    rank, world_size = args.rank, args.world_size
+
+    set_random_seed(cfg.seed, use_rank_shift=True)
+    cudnn.benchmark = True
+
+    os.makedirs(cfg.output, exist_ok=True)
+    logger = get_logger(cfg)
+
+    if dist.get_rank() == 0:
+        path = os.path.join(cfg.output, 'config.json')
+        OmegaConf.save(cfg, path)
+        logger.info(f'Full config saved to {path}')
+
+    # print config
+    logger.info(OmegaConf.to_yaml(cfg))
+
+    inference(cfg)
+    dist.barrier()
+
+
+if __name__ == '__main__':
+    main()
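
For context, a minimal sketch (not part of this commit) of how the three launch modes above feed `init_distributed_mode`: `torchrun`/`torch.distributed.launch` exports RANK, WORLD_SIZE and LOCAL_RANK; SLURM jobs expose SLURM_PROCID, SLURM_NTASKS and SLURM_NODELIST, from which the master address and port are derived; otherwise the single-GPU fallback fills in 127.0.0.1:29500. The argument object below is a hypothetical stand-in for whatever `parse_args()` returns.

    import argparse
    import torch.distributed as dist

    # Hypothetical stand-in for the namespace returned by parse_args() in main_seg.py.
    args = argparse.Namespace(dist_url='env://', rank=0, gpu=0, world_size=1)

    # Under e.g. `torchrun --nproc_per_node=8 main_seg.py ...`, RANK, WORLD_SIZE and
    # LOCAL_RANK are already set, so init_distributed_mode() takes the first branch,
    # initializes the NCCL process group and silences print() on non-master ranks.
    init_distributed_mode(args)
    print(f'rank {dist.get_rank()} / world size {dist.get_world_size()}')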

+ 23 - 0
models/__init__.py

@@ -0,0 +1,23 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Written by Jilan Xu
+# -------------------------------------------------------------------------
+
+from .builder import build_model
+from .group_vit import GroupViT
+from .multi_label_contrastive import MultiLabelContrastive
+from .transformer import TextTransformer
+from .transformer import DistilBert, Bert, BertMedium, Roberta
+
+__all__ = ['build_model', 'MultiLabelContrastive', 'GroupViT', 'TextTransformer',
+           'DistilBert', 'Bert', 'BertMedium', 'Roberta']

+ 24 - 0
models/builder.py

@@ -0,0 +1,24 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from mmcv.utils import Registry
+from omegaconf import OmegaConf
+
+MODELS = Registry('model')
+
+
+def build_model(config):
+
+    model = MODELS.build(OmegaConf.to_container(config, resolve=True))
+
+    return model
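
A usage sketch for the registry above (illustrative only; real model configs live under configs/*.yml): `build_model` converts the OmegaConf node to a plain dict and hands it to the mmcv `Registry`, which dispatches on the `type` key and forwards the remaining keys as keyword arguments to the registered class.

    from omegaconf import OmegaConf
    from models import build_model

    # Minimal illustrative config: 'type' must match a class registered with
    # @MODELS.register_module(), e.g. GroupViT below; other keys become constructor kwargs.
    cfg = OmegaConf.create({'type': 'GroupViT', 'img_size': 224, 'patch_size': 16})
    model = build_model(cfg)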

+ 537 - 0
models/clipmodel.py

@@ -0,0 +1,537 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from ipdb import set_trace
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+
+        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(OrderedDict([
+                ("-1", nn.AvgPool2d(stride)),
+                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+                ("1", nn.BatchNorm2d(planes * self.expansion))
+            ]))
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.relu(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+        return out
+
+
+class AttentionPool2d(nn.Module):
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x, key=x, value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False
+        )
+
+        return x[0]
+
+
+class ModifiedResNet(nn.Module):
+    """
+    A ResNet class that is similar to torchvision's but contains the following changes:
+    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+    - The final pooling layer is a QKV attention instead of an average pool
+    """
+
+    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+        super().__init__()
+        self.output_dim = output_dim
+        self.input_resolution = input_resolution
+
+        # the 3-layer stem
+        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(width // 2)
+        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(width // 2)
+        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(width)
+        self.avgpool = nn.AvgPool2d(2)
+        self.relu = nn.ReLU(inplace=True)
+
+        # residual layers
+        self._inplanes = width  # this is a *mutable* variable used during construction
+        self.layer1 = self._make_layer(width, layers[0])
+        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+        embed_dim = width * 32  # the ResNet feature dimension
+        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+    def _make_layer(self, planes, blocks, stride=1):
+        layers = [Bottleneck(self._inplanes, planes, stride)]
+
+        self._inplanes = planes * Bottleneck.expansion
+        for _ in range(1, blocks):
+            layers.append(Bottleneck(self._inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        def stem(x):
+            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
+                x = self.relu(bn(conv(x)))
+            x = self.avgpool(x)
+            return x
+
+        x = x.type(self.conv1.weight.dtype)
+        x = stem(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.attnpool(x)
+
+        return x
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.output_dim = output_dim
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+        self.ln_pre = LayerNorm(width)
+
+        self.transformer = Transformer(width, layers, heads)
+
+        self.ln_post = LayerNorm(width)
+        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+    ### return all features ####
+    def forward_features(self, x: torch.Tensor, pos_embedding: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+        ### training time or zero-shot eval time, keep class embedding (when pos and x have different token shapes)
+        # if pos_embedding.ndim == 2 and x.size(1) != pos_embedding.size(0) or \
+        #     pos_embedding.ndim == 3 and x.size(1) != pos_embedding.size(1):
+        #     x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        #     # pass
+        
+        # x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+
+        # set_trace()
+        #### Note that we maintain the pos_embed size to match the input
+        # x = x + self.positional_embedding.to(x.dtype)
+        x = x + pos_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        # x = x[:,0,:]
+        ### added ###
+        x = self.ln_post(x)
+        if self.proj is not None:
+            x = x @ self.proj
+        return x
+
+    
+    def forward_features_with_clstoken(self, x: torch.Tensor, pos_embedding: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        cls_token = self.ln_post(x[:, :1, :])  # cls token 
+        img_token = self.ln_post(x[:, 1:, :]) # img token
+        
+        ### without projection to image-text space ###
+        return cls_token, img_token 
+
+    
+    def forward_features_with_prompts(self, x: torch.Tensor, prompt_token: torch.Tensor, pos_embedding: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+        x = x + pos_embedding.to(x.dtype)
+        # [1, hw, d] || [1, 16, d] == [1, hw+16, d]
+        x = torch.cat((prompt_token, x), dim=1)
+        
+        x = self.ln_pre(x)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        # x = x[:,0,:]
+        ### added ###
+        x = self.ln_post(x)
+        if self.proj is not None:
+            x = x @ self.proj
+        return x
+
+
+    ### return multi-scale features ###
+    def forward_multiscale_features(self, x: torch.Tensor, pos_embedding: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+        x = x + pos_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+
+        all_x = []  # intermediate (multi-scale) features; note only the final x is returned below
+        for i in range(12):
+            x = self.transformer.resblocks[i](x)  # apply the i-th residual attention block
+            if (i and i % 3 == 0) or i == 11:  # collect features after blocks 3, 6, 9 and the last block
+                cur_x = x.permute(1, 0, 2)
+                cur_x = self.ln_post(cur_x)
+                cur_x = cur_x @ self.proj 
+                all_x.append(cur_x)
+        
+        # x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        # x = x[:,0,:]
+        ### added ###
+        x = self.ln_post(x)
+        if self.proj is not None:
+            x = x @ self.proj
+        return x
+
+
+    ### return cls token ####
+    def forward(self, x: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_post(x[:, 0, :])
+
+        if self.proj is not None:
+            x = x @ self.proj
+
+        return x
+
+
+class CLIP(nn.Module):
+    def __init__(self,
+                 embed_dim: int,
+                 # vision
+                 image_resolution: int,
+                 vision_layers: Union[Tuple[int, int, int, int], int],
+                 vision_width: int,
+                 vision_patch_size: int,
+                 # text
+                 context_length: int,
+                 vocab_size: int,
+                 transformer_width: int,
+                 transformer_heads: int,
+                 transformer_layers: int
+                 ):
+        super().__init__()
+
+        self.context_length = context_length
+
+        if isinstance(vision_layers, (tuple, list)):
+            vision_heads = vision_width * 32 // 64
+            self.visual = ModifiedResNet(
+                layers=vision_layers,
+                output_dim=embed_dim,
+                heads=vision_heads,
+                input_resolution=image_resolution,
+                width=vision_width
+            )
+        else:
+            vision_heads = vision_width // 64
+            self.visual = VisionTransformer(
+                input_resolution=image_resolution,
+                patch_size=vision_patch_size,
+                width=vision_width,
+                layers=vision_layers,
+                heads=vision_heads,
+                output_dim=embed_dim
+            )
+
+        self.transformer = Transformer(
+            width=transformer_width,
+            layers=transformer_layers,
+            heads=transformer_heads,
+            attn_mask=self.build_attention_mask()
+        )
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        self.initialize_parameters()
+
+    def initialize_parameters(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+
+        if isinstance(self.visual, ModifiedResNet):
+            if self.visual.attnpool is not None:
+                std = self.visual.attnpool.c_proj.in_features ** -0.5
+                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+                for name, param in resnet_block.named_parameters():
+                    if name.endswith("bn3.weight"):
+                        nn.init.zeros_(param)
+
+        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+        attn_std = self.transformer.width ** -0.5
+        fc_std = (2 * self.transformer.width) ** -0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        if self.text_projection is not None:
+            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    @property
+    def dtype(self):
+        return self.visual.conv1.weight.dtype
+
+    def encode_image(self, image):
+        return self.visual(image.type(self.dtype))
+
+    def encode_text(self, text):
+        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
+
+        x = x + self.positional_embedding.type(self.dtype)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x).type(self.dtype)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+        return x
+
+    def forward(self, image, text):
+        image_features = self.encode_image(image)
+        text_features = self.encode_text(text)
+
+        # normalized features
+        image_features = image_features / image_features.norm(dim=1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=1, keepdim=True)
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_image = logit_scale * image_features @ text_features.t()
+        logits_per_text = logits_per_image.t()
+
+        # shape = [global_batch_size, global_batch_size]
+        return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+    """Convert applicable model parameters to fp16"""
+
+    def _convert_weights_to_fp16(l):
+        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+            l.weight.data = l.weight.data.half()
+            if l.bias is not None:
+                l.bias.data = l.bias.data.half()
+
+        if isinstance(l, nn.MultiheadAttention):
+            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+                tensor = getattr(l, attr)
+                if tensor is not None:
+                    tensor.data = tensor.data.half()
+
+        for name in ["text_projection", "proj"]:
+            if hasattr(l, name):
+                attr = getattr(l, name)
+                if attr is not None:
+                    attr.data = attr.data.half()
+
+    model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict):
+    vit = "visual.proj" in state_dict
+
+    if vit:
+        vision_width = state_dict["visual.conv1.weight"].shape[0]
+        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+        image_resolution = vision_patch_size * grid_size
+    else:
+        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+        vision_layers = tuple(counts)
+        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+        vision_patch_size = None
+        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+        image_resolution = output_width * 32
+
+    embed_dim = state_dict["text_projection"].shape[1]
+    context_length = state_dict["positional_embedding"].shape[0]
+    vocab_size = state_dict["token_embedding.weight"].shape[0]
+    transformer_width = state_dict["ln_final.weight"].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+    model = CLIP(
+        embed_dim,
+        image_resolution, vision_layers, vision_width, vision_patch_size,
+        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
+    )
+
+    for key in ["input_resolution", "context_length", "vocab_size"]:
+        if key in state_dict:
+            del state_dict[key]
+
+    convert_weights(model)
+    model.load_state_dict(state_dict)
+    return model.eval()
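
A brief sketch of how `build_model(state_dict)` above is typically used (the checkpoint path is an assumption, not part of this commit): OpenAI's released CLIP weights are TorchScript archives, so the state dict is pulled from the scripted module and the architecture is inferred from the tensor shapes.

    import torch

    ckpt_path = 'checkpoints/ViT-B-16.pt'  # hypothetical local copy of an OpenAI CLIP checkpoint

    # OpenAI ships CLIP as a TorchScript archive; fall back to a plain state dict
    # if the file was re-saved with torch.save().
    try:
        state_dict = torch.jit.load(ckpt_path, map_location='cpu').state_dict()
    except RuntimeError:
        state_dict = torch.load(ckpt_path, map_location='cpu')

    model = build_model(state_dict)  # infers ViT vs. ResNet from the keys and returns model.eval()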

+ 1014 - 0
models/group_vit.py

@@ -0,0 +1,1014 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from einops import rearrange
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from .builder import MODELS
+from .misc import Result, interpolate_pos_encoding
+from ipdb import set_trace
+import clip
+import cv2
+
+
+class Mlp(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class MixerMlp(Mlp):
+
+    def forward(self, x):
+        return super().forward(x.transpose(1, 2)).transpose(1, 2)
+
+
+def hard_softmax(logits, dim):
+    y_soft = logits.softmax(dim)
+    # Straight through.
+    index = y_soft.max(dim, keepdim=True)[1]
+    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+    ret = y_hard - y_soft.detach() + y_soft
+
+    return ret
+
+
+def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
+    # _gumbels = (-torch.empty_like(
+    #     logits,
+    #     memory_format=torch.legacy_contiguous_format).exponential_().log()
+    #             )  # ~Gumbel(0,1)
+    # more stable https://github.com/pytorch/pytorch/issues/41663
+    gumbel_dist = torch.distributions.gumbel.Gumbel(
+        torch.tensor(0., device=logits.device, dtype=logits.dtype),
+        torch.tensor(1., device=logits.device, dtype=logits.dtype))
+    gumbels = gumbel_dist.sample(logits.shape)
+
+    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
+    y_soft = gumbels.softmax(dim)
+
+    if hard:
+        # Straight through.
+        index = y_soft.max(dim, keepdim=True)[1]
+        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+        ret = y_hard - y_soft.detach() + y_soft
+    else:
+        # Reparametrization trick.
+        ret = y_soft
+    return ret
+
+
+class AssignAttention(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads=1,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 hard=True,
+                 gumbel=False,
+                 gumbel_tau=1.,
+                 sum_assign=False,
+                 assign_eps=1.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
+        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
+        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.hard = hard
+        self.gumbel = gumbel
+        self.gumbel_tau = gumbel_tau
+        self.sum_assign = sum_assign
+        self.assign_eps = assign_eps
+
+    def get_attn(self, attn, gumbel=None, hard=None):
+
+        if gumbel is None:
+            gumbel = self.gumbel
+
+        if hard is None:
+            hard = self.hard
+
+        attn_dim = -2
+        if gumbel and self.training:
+            attn = gumbel_softmax(attn, dim=attn_dim, hard=hard, tau=self.gumbel_tau)
+        else:
+            if hard:
+                attn = hard_softmax(attn, dim=attn_dim)
+            else:
+                attn = F.softmax(attn, dim=attn_dim)
+
+        return attn
+        
+    def forward(self, query, key=None, *, value=None, return_attn=False, mask=None):
+        B, N, C = query.shape
+        if key is None:
+            key = query
+        if value is None:
+            value = key
+        S = key.size(1)
+        # [B, nh, N, C//nh]
+        q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
+        # [B, nh, S, C//nh]
+        k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
+        # [B, nh, S, C//nh]
+        v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
+
+        # [B, nh, N, S]
+        raw_attn = (q @ k.transpose(-2, -1)) * self.scale
+
+        attn = self.get_attn(raw_attn)
+        if return_attn:
+            hard_attn = attn.clone()
+            soft_attn = self.get_attn(raw_attn, gumbel=False, hard=False)
+            attn_dict = {'hard': hard_attn, 'soft': soft_attn, 'rawk': key, 'rawq':query, 'k':k, 'q':q}
+        else:
+            attn_dict = None
+
+        if not self.sum_assign:
+            attn = attn / (attn.sum(dim=-1, keepdim=True) + self.assign_eps)
+        attn = self.attn_drop(attn)
+        assert attn.shape == (B, self.num_heads, N, S)
+
+        # [B, nh, N, C//nh] <- [B, nh, N, S] @ [B, nh, S, C//nh]
+        out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
+
+        out = self.proj(out)
+        out = self.proj_drop(out)
+        return out, attn_dict
+
+    def extra_repr(self):
+        return f'num_heads: {self.num_heads}, \n' \
+               f'hard: {self.hard}, \n' \
+               f'gumbel: {self.gumbel}, \n' \
+               f'sum_assign={self.sum_assign}, \n' \
+               f'gumbel_tau: {self.gumbel_tau}, \n' \
+               f'assign_eps: {self.assign_eps}'
+
+
+class GroupingBlock(nn.Module):
+    """Grouping Block to group similar segments together.
+
+    Args:
+        dim (int): Dimension of the input.
+        out_dim (int): Dimension of the output.
+        num_heads (int): Number of heads in the grouping attention.
+        num_group_token (int): Number of input group tokens.
+        num_output_group (int): Number of output groups.
+        norm_layer (nn.Module): Normalization layer to use.
+        mlp_ratio (float | tuple[float]): Ratio of mlp hidden dim to embedding dim. Default: (0.5, 4.0)
+        hard (bool): Whether to use hard or soft assignment. Default: True
+        gumbel (bool): Whether to use gumbel softmax. Default: True
+        sum_assign (bool): Whether to sum assignment or average. Default: False
+        assign_eps (float): Epsilon to avoid divide by zero. Default: 1
+        gumbel_tau (float): Temperature for gumbel softmax. Default: 1
+        attn_drop (float): Attention dropout rate. Default: 0
+    """
+
+    def __init__(self,
+                 *,
+                 dim,
+                 out_dim,
+                 num_heads,
+                 num_group_token,
+                 num_output_group,
+                 norm_layer,
+                 mlp_ratio=(0.5, 4.0),
+                 hard=True,
+                 gumbel=True,
+                 sum_assign=False,
+                 assign_eps=1.,
+                 gumbel_tau=1.,
+                 attn_drop=0.,
+                 ):
+        super(GroupingBlock, self).__init__()
+        self.dim = dim
+        self.hard = hard
+        self.gumbel = gumbel
+        self.sum_assign = sum_assign
+        self.num_output_group = num_output_group
+        # norm on group_tokens
+        self.norm_tokens = norm_layer(dim)
+        tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
+        self.mlp_inter = Mlp(num_group_token, tokens_dim, num_output_group)
+        self.norm_post_tokens = norm_layer(dim)
+        # norm on x
+        self.norm_x = norm_layer(dim)
+        self.pre_assign_attn = CrossAttnBlock(
+            dim=dim, num_heads=num_heads, mlp_ratio=4, qkv_bias=True, norm_layer=norm_layer, post_norm=True)
+
+        self.assign = AssignAttention(
+            dim=dim,
+            num_heads=1,
+            qkv_bias=True,
+            hard=hard,
+            gumbel=gumbel,
+            gumbel_tau=gumbel_tau,
+            sum_assign=sum_assign,
+            assign_eps=assign_eps,
+            attn_drop=attn_drop,
+            )
+        self.norm_new_x = norm_layer(dim)
+        self.mlp_channels = Mlp(dim, channels_dim, out_dim)
+        if out_dim is not None and dim != out_dim:
+            self.reduction = nn.Sequential(norm_layer(dim), nn.Linear(dim, out_dim, bias=False))
+        else:
+            self.reduction = nn.Identity()
+
+    def extra_repr(self):
+        return f'hard={self.hard}, \n' \
+               f'gumbel={self.gumbel}, \n' \
+               f'sum_assign={self.sum_assign}, \n' \
+               f'num_output_group={self.num_output_group}, \n '
+
+    def project_group_token(self, group_tokens):
+        """
+        Args:
+            group_tokens (torch.Tensor): group tokens, [B, S_1, C]
+
+        inter_weight (torch.Tensor): [B, S_2, S_1], S_2 is the new number of
+            group tokens, it's already softmaxed along dim=-1
+
+        Returns:
+            projected_group_tokens (torch.Tensor): [B, S_2, C]
+        """
+        # [B, S_2, C] <- [B, S_1, C]
+        projected_group_tokens = self.mlp_inter(group_tokens.transpose(1, 2)).transpose(1, 2)
+        projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
+        return projected_group_tokens
+
+    def forward(self, x, group_tokens, return_attn=False):
+        """
+        Args:
+            x (torch.Tensor): image tokens, [B, L, C]
+            group_tokens (torch.Tensor): group tokens, [B, S_1, C]
+            return_attn (bool): whether to return attention map
+
+        Returns:
+            new_x (torch.Tensor): [B, S_2, C], S_2 is the new number of
+                group tokens
+        """
+        group_tokens = self.norm_tokens(group_tokens)
+        x = self.norm_x(x)
+        # [B, S_2, C]
+        projected_group_tokens = self.project_group_token(group_tokens)
+        projected_group_tokens = self.pre_assign_attn(projected_group_tokens, x)
+        new_x, attn_dict = self.assign(projected_group_tokens, x, return_attn=return_attn)
+        new_x += projected_group_tokens
+
+        new_x = self.reduction(new_x) + self.mlp_channels(self.norm_new_x(new_x))
+
+        return new_x, attn_dict
+
+
+class Attention(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 out_dim=None,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 qkv_fuse=False):
+        super().__init__()
+        if out_dim is None:
+            out_dim = dim
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+        self.qkv_fuse = qkv_fuse
+
+        if qkv_fuse:
+            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        else:
+            self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
+            self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
+            self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, out_dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def extra_repr(self):
+        return f'num_heads={self.num_heads}, \n' \
+               f'qk_scale={self.scale}, \n' \
+               f'qkv_fuse={self.qkv_fuse}'
+
+    def forward(self, query, key=None, *, value=None, mask=None):
+        if self.qkv_fuse:
+            assert key is None
+            assert value is None
+            x = query
+            B, N, C = x.shape
+            S = N
+            # [3, B, nh, N, C//nh]
+            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+            # [B, nh, N, C//nh]
+            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        else:
+            B, N, C = query.shape
+            if key is None:
+                key = query
+            if value is None:
+                value = key
+            S = key.size(1)
+            # [B, nh, N, C//nh]
+            q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
+            # [B, nh, S, C//nh]
+            k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
+            # [B, nh, S, C//nh]
+            v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
+
+        # [B, nh, N, S]
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        
+        attn = self.attn_drop(attn)
+        assert attn.shape == (B, self.num_heads, N, S)
+
+        # [B, nh, N, C//nh] -> [B, N, C]
+        # out = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
+        out = self.proj(out)
+        out = self.proj_drop(out)
+        return out
+
+
+class CrossAttnBlock(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm,
+                 post_norm=False):
+        super().__init__()
+        if post_norm:
+            self.norm_post = norm_layer(dim)
+            self.norm_q = nn.Identity()
+            self.norm_k = nn.Identity()
+        else:
+            self.norm_q = norm_layer(dim)
+            self.norm_k = norm_layer(dim)
+            self.norm_post = nn.Identity()
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, query, key, *, mask=None):
+        x = query
+        x = x + self.drop_path(self.attn(self.norm_q(query), self.norm_k(key), mask=mask))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = self.norm_post(x)
+        return x
+
+
+class AttnBlock(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            qkv_fuse=True)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x, mask=None):
+        x = x + self.drop_path(self.attn(self.norm1(x), mask=mask))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class GroupingLayer(nn.Module):
+    """A Transformer layer with Grouping Block for one stage.
+
+    Args:
+        dim (int): Number of input channels.
+        num_input_token (int): Number of input tokens.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        num_group_token (int): Number of group tokens appended to the image tokens.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer.
+            In GroupViT setting, Grouping Block serves as the downsampling layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+        group_projector (nn.Module | None, optional): Projector for the grouping layer. Default: None.
+        zero_init_group_token (bool): Whether to initialize the grouping token to 0. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 num_input_token,
+                 depth,
+                 num_heads,
+                 num_group_token,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None,
+                 use_checkpoint=False,
+                 group_projector=None,
+                 zero_init_group_token=False,
+                 ):
+
+        super().__init__()
+        self.dim = dim
+        self.input_length = num_input_token
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+        self.num_group_token = num_group_token
+        if num_group_token > 0:
+            self.group_token = nn.Parameter(torch.zeros(1, num_group_token, dim))
+            if not zero_init_group_token:
+                trunc_normal_(self.group_token, std=.02)
+        else:
+            self.group_token = None
+
+        # build blocks
+        self.depth = depth
+        blocks = []
+        for i in range(depth):
+            blocks.append(
+                AttnBlock(
+                    dim=dim,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    qk_scale=qk_scale,
+                    drop=drop,
+                    attn_drop=attn_drop,
+                    drop_path=drop_path[i],
+                    norm_layer=norm_layer))
+        self.blocks = nn.ModuleList(blocks)
+
+        self.downsample = downsample
+        self.input_resolution = num_input_token
+        self.use_checkpoint = use_checkpoint
+
+        self.group_projector = group_projector
+
+    @property
+    def with_group_token(self):
+        return self.group_token is not None
+
+    def extra_repr(self):
+        return f'dim={self.dim}, \n' \
+               f'input_resolution={self.input_resolution}, \n' \
+               f'depth={self.depth}, \n' \
+               f'num_group_token={self.num_group_token}, \n'
+
+    def split_x(self, x):
+        if self.with_group_token:
+            return x[:, :-self.num_group_token], x[:, -self.num_group_token:]
+        else:
+            return x, None
+
+    def concat_x(self, x, group_token=None):
+        if group_token is None:
+            return x
+        return torch.cat([x, group_token], dim=1)
+
+    def forward(self, x, prev_group_token=None, return_attn=False):
+        """
+        Args:
+            x (torch.Tensor): image tokens, [B, L, C]
+            prev_group_token (torch.Tensor): group tokens, [B, S_1, C]
+            return_attn (bool): whether to return attention maps
+        """
+        if self.with_group_token:
+            group_token = self.group_token.expand(x.size(0), -1, -1)
+            if self.group_projector is not None:
+                group_token = group_token + self.group_projector(prev_group_token)
+        else:
+            group_token = None
+
+        B, L, C = x.shape
+        cat_x = self.concat_x(x, group_token)
+
+        for blk_idx, blk in enumerate(self.blocks):
+            if self.use_checkpoint:
+                cat_x = checkpoint.checkpoint(blk, cat_x)
+            else:
+                cat_x = blk(cat_x)
+
+        x, group_token = self.split_x(cat_x)
+
+        attn_dict = None
+        if self.downsample is not None:
+            x, attn_dict = self.downsample(x, group_token, return_attn=return_attn)
+        return x, group_token, attn_dict
+
+
+class PatchEmbed(nn.Module):
+    """Image to Patch Embedding."""
+
+    def __init__(self, img_size=224, kernel_size=7, stride=4, padding=2, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        kernel_size = to_2tuple(kernel_size)
+        stride = to_2tuple(stride)
+        padding = to_2tuple(padding)
+        self.img_size = img_size
+        self.patches_resolution = (
+            int((img_size[1] + 2 * padding[1] - kernel_size[1]) / stride[1] + 1),
+            int((img_size[0] + 2 * padding[0] - kernel_size[0]) / stride[0] + 1),
+        )
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    @property
+    def num_patches(self):
+        return self.patches_resolution[1] * self.patches_resolution[0]
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        if self.training:
+            # FIXME look at relaxing size constraints
+            assert H == self.img_size[0] and W == self.img_size[1], \
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        hw_shape = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
+        if self.norm is not None:
+            x = self.norm(x)
+        return x, hw_shape
+
+
+@MODELS.register_module()
+class GroupViT(nn.Module):
+    r""" Group Vision Transformer
+        A PyTorch impl of : `GroupViT: Semantic Segmentation Emerges from Text Supervision`  -
+          https://arxiv.org/pdf/2202.11094.pdf
+
+    Args:
+        img_size (int | tuple[int]): Input image size. Default 224
+        patch_size (int | tuple[int]): Patch size. Default: 16
+        in_chans (int): Number of input image channels. Default: 3
+        num_classes (int): Number of classes for classification head. Default: 0
+        embed_dim (int): Patch embedding dimension. Default: 384
+        embed_factors (list[int]): Embedding dim multipliers for each stage.
+        depths (list[int]): Depth of each stage
+        num_heads (list[int]): Number of heads for each stage
+        num_group_tokens (list[int]): Number of group tokens for each stage
+        num_output_groups (list[int]): Number of output groups for each stage
+        hard_assignment (bool): Whether to use hard assignment or not. Default: True
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+        drop_rate (float): Dropout rate. Default: 0
+        attn_drop_rate (float): Attention dropout rate. Default: 0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.1
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+        pos_embed_type (str): Type of positional embedding. Default: 'simple'
+        freeze_patch_embed (bool): Whether to freeze patch embedding. Default: False
+    """
+
+    def __init__(self,
+                 img_size=224,
+                 patch_size=16,
+                 in_chans=3,
+                 num_classes=0,
+                 embed_dim=384,
+                 embed_factors=[1, 1, 1],
+                 depths=[6, 3, 3],
+                 num_heads=[6, 6, 6],
+                 num_group_tokens=[64, 8, 0],
+                 num_output_groups=[64, 8],
+                 hard_assignment=True,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.1,
+                 patch_norm=True,
+                 use_checkpoint=False,
+                 pos_embed_type='simple',
+                 freeze_patch_embed=False,
+                 imgnet_pretrained=None,
+                 fixed=False,
+                 imgnet_pretrained_checkpoint='/mnt/petrelfs/xujilan/checkpoints/dino_vitbase16_pretrain.pth',
+                 ):
+        super().__init__()
+        assert patch_size in [4, 8, 16]
+        self.num_classes = num_classes
+        assert len(embed_factors) == len(depths) == len(num_group_tokens)
+        assert all(_ == 0 for _ in num_heads) or len(depths) == len(num_heads)
+        # assert len(depths) - 1 == len(num_output_groups)
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.patch_norm = patch_norm
+        self.num_features = int(embed_dim * embed_factors[len(depths) - 1])
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.qk_scale = qk_scale
+        self.drop_rate = drop_rate
+        self.attn_drop_rate = attn_drop_rate
+        self.drop_path_rate = drop_path_rate
+        self.num_group_tokens = num_group_tokens
+        self.num_output_groups = num_output_groups
+        self.pos_embed_type = pos_embed_type
+        assert pos_embed_type in ['simple', 'fourier']
+        self.freeze_backbone = fixed
+
+        norm_layer = nn.LayerNorm
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            img_size=img_size,
+            kernel_size=patch_size,
+            stride=patch_size,
+            padding=0,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+        num_patches = self.patch_embed.num_patches
+        patches_resolution = self.patch_embed.patches_resolution
+        self.patches_resolution = patches_resolution
+
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
+
+        if pos_embed_type == 'simple':
+            self.pos_embed = self.build_simple_position_embedding()
+        elif pos_embed_type == 'fourier':
+            self.pos_embed = self.build_2d_sincos_position_embedding()
+        else:
+            raise ValueError
+
+        if freeze_patch_embed:
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+            self.pos_embed.requires_grad = False
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+
+        num_input_token = num_patches
+        num_output_token = num_input_token
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+
+            dim = int(embed_dim * embed_factors[i_layer])
+            downsample = None
+            
+            if i_layer < self.num_layers - 1:
+                out_dim = embed_dim * embed_factors[i_layer + 1]
+                downsample = GroupingBlock(
+                    dim=dim,
+                    out_dim=out_dim,
+                    num_heads=num_heads[i_layer],
+                    num_group_token=num_group_tokens[i_layer],
+                    num_output_group=num_output_groups[i_layer],
+                    norm_layer=norm_layer,
+                    hard=hard_assignment,
+                    gumbel=hard_assignment,
+                    attn_drop=attn_drop_rate,
+                    )
+                num_output_token = num_output_groups[i_layer]
+            
+            
+            if i_layer > 0 and num_group_tokens[i_layer] > 0:
+                prev_dim = int(embed_dim * embed_factors[i_layer - 1])
+                group_projector = nn.Sequential(
+                    norm_layer(prev_dim),
+                    MixerMlp(num_group_tokens[i_layer - 1], prev_dim // 2, num_group_tokens[i_layer]))
+
+                if dim != prev_dim:
+                    group_projector = nn.Sequential(group_projector, norm_layer(prev_dim),
+                                                    nn.Linear(prev_dim, dim, bias=False))
+            else:
+                group_projector = None
+            layer = GroupingLayer(
+                dim=dim,
+                num_input_token=num_input_token,
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                num_group_token=num_group_tokens[i_layer],
+                mlp_ratio=self.mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=downsample,
+                use_checkpoint=use_checkpoint,
+                group_projector=group_projector,
+                # only zero init group token if we have a projection
+                zero_init_group_token=group_projector is not None,
+                )
+            self.layers.append(layer)
+            if i_layer < self.num_layers - 1:
+                num_input_token = num_output_token
+
+        self.norm = norm_layer(self.num_features)
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+        self.apply(self._init_weights)
+        self.imgnet_pretrained = imgnet_pretrained
+        self.proj = None
+
+        if imgnet_pretrained is not None:
+            ### add cls_token to enable params loading ###
+            self.pos_embed = self.build_simple_position_embedding_with_cls_token()
+            self.init_backbone_with_imagenet_weights(imgnet_pretrained_checkpoint)
+            ### drop cls_token ###
+            self.pos_embed = nn.Parameter(self.pos_embed[0, 1:])
+            
+    def init_backbone_with_imagenet_weights(self, checkpoint_path):
+        if self.imgnet_pretrained == 'imgnet':
+            from timm.models import vit_base_patch16_224
+            net = vit_base_patch16_224(pretrained=True)
+            state_dict = net.state_dict()
+        elif self.imgnet_pretrained in ['dino', 'dinob8', 'dinos16', 'dinos8']:
+            state_dict = torch.load(checkpoint_path)
+        elif self.imgnet_pretrained == 'clip':
+            clip_model, _ = clip.load('ViT-B/16', device='cuda', jit=False)
+            state_dict = clip_model.visual.state_dict()
+        else:
+            raise ValueError(f'Unknown imgnet_pretrained type: {self.imgnet_pretrained}')
+
+        print('Initializing backbone with {} pretrained weights'.format(self.imgnet_pretrained))
+        print('$' * 100)
+        newdict = {}
+        if self.imgnet_pretrained != 'clip':
+            if self.num_layers == 2:
+                for kk, vv in state_dict.items():
+                    newkey = kk
+                    if kk.startswith('blocks.'):
+                        layerid = int(kk.split('.')[1])
+                        if 0 <= layerid < 6:
+                            newkey = 'layers.0.' + kk
+                        elif 6 <= layerid < 12:
+                            old_prefix = 'blocks.' + str(layerid) + '.'
+                            new_prefix = 'blocks.' + str(layerid - 6) + '.'
+                            suffix = kk.split(old_prefix)[1]
+                            newkey = 'layers.1.' + new_prefix + suffix
+                    newdict[newkey] = vv
+            elif self.num_layers == 3:
+                for kk, vv in state_dict.items():
+                    newkey = kk
+                    if kk.startswith('blocks.'):
+                        layerid = int(kk.split('.')[1])
+                        if 0 <= layerid < 6:
+                            newkey = 'layers.0.' + kk
+                        elif 6 <= layerid < 9:
+                            old_prefix = 'blocks.' + str(layerid) + '.'
+                            new_prefix = 'blocks.' + str(layerid - 6) + '.'
+                            suffix = kk.split(old_prefix)[1]
+                            newkey = 'layers.1.' + new_prefix + suffix
+                        elif 9 <= layerid < 12:
+                            old_prefix = 'blocks.' + str(layerid) + '.'
+                            new_prefix = 'blocks.' + str(layerid - 9) + '.'
+                            suffix = kk.split(old_prefix)[1]
+                            newkey = 'layers.2.' + new_prefix + suffix
+                    newdict[newkey] = vv
+        else:
+            for kk, vv in state_dict.items():
+                newkey = kk
+                newkey = newkey.replace('transformer.','')
+                newkey = newkey.replace('resblocks', 'blocks')
+                
+                newkey = newkey.replace('attn.in_proj_weight','attn.qkv.weight')
+                newkey = newkey.replace('attn.in_proj_bias','attn.qkv.bias')
+                newkey = newkey.replace('attn.out_proj.weight','attn.proj.weight')
+                newkey = newkey.replace('attn.out_proj.bias','attn.proj.bias')
+                
+                newkey = newkey.replace('ln_1.weight','norm1.weight')
+                newkey = newkey.replace('ln_1.bias','norm1.bias')
+                newkey = newkey.replace('ln_2.weight','norm2.weight')
+                newkey = newkey.replace('ln_2.bias','norm2.bias')
+                
+                newkey = newkey.replace('mlp.c_fc.weight','mlp.fc1.weight')
+                newkey = newkey.replace('mlp.c_fc.bias', 'mlp.fc1.bias')
+                newkey = newkey.replace('mlp.c_proj.weight','mlp.fc2.weight')
+                newkey = newkey.replace('mlp.c_proj.bias', 'mlp.fc2.bias')
+                
+                newkey = newkey.replace('ln_post.weight', 'norm.weight')
+                newkey = newkey.replace('ln_post.bias', 'norm.bias')
+                
+                newkey = newkey.replace('positional_embedding', 'pos_embed')
+                newkey = newkey.replace('conv1.weight', 'patch_embed.proj.weight')
+                
+                kk = newkey
+                if newkey == 'proj':
+                    self.proj = nn.Parameter(torch.zeros(vv.shape[0], vv.shape[1]))
+
+                if newkey == 'pos_embed':
+                    vv = vv.unsqueeze(0)
+                if kk.startswith('blocks.'):
+                    layerid = int(kk.split('.')[1])
+                    if 0 <= layerid < 6:
+                        newkey = 'layers.0.' + kk
+                    elif 6 <= layerid < 12:
+                        old_prefix = 'blocks.' + str(layerid) + '.'
+                        new_prefix = 'blocks.' + str(layerid - 6) + '.'
+                        suffix = kk.split(old_prefix)[1]
+                        newkey = 'layers.1.' + new_prefix + suffix
+                newdict[newkey] = vv
+                
+        ### init all self-attn/pos_embed/patch_embed layers ###
+        msg = self.load_state_dict(newdict, strict=False)
+        if self.freeze_backbone:
+            for n, p in self.named_parameters():
+                if n in newdict:
+                    p.requires_grad = False
+                    print('Freezing parameter: ', n)
+
+        print(msg)
+        print('$' * 100)
+        
+    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
+        if self.pos_embed_type == 'simple' and 'pos_embed' in state_dict:
+            load_pos_embed = state_dict['pos_embed']
+            pos_embed = self.pos_embed
+            if load_pos_embed.shape != pos_embed.shape:
+                H_new = int(self.patch_embed.num_patches**0.5)
+                W_new = H_new
+                H_ori = int(load_pos_embed.shape[1]**0.5)
+                W_ori = H_ori
+                load_pos_embed = F.interpolate(
+                    rearrange(load_pos_embed, 'b (h w) c -> b c h w', h=H_ori, w=W_ori, b=1),
+                    size=(H_new, W_new),
+                    mode='bicubic',
+                    align_corners=False)
+                load_pos_embed = rearrange(load_pos_embed, 'b c h w -> b (h w) c', h=H_new, w=W_new)
+                state_dict['pos_embed'] = load_pos_embed
+        return super().load_state_dict(state_dict, strict)
+
+    def build_simple_position_embedding(self):
+        pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches, self.embed_dim))
+        trunc_normal_(pos_embed, std=.02)
+        return pos_embed
+
+    def build_simple_position_embedding_with_cls_token(self):
+        pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
+        trunc_normal_(pos_embed, std=.02)
+        return pos_embed
+
+    def build_2d_sincos_position_embedding(self, temperature=10000.):
+        h, w = self.patch_embed.patches_resolution
+        grid_w = torch.arange(w, dtype=torch.float32)
+        grid_h = torch.arange(h, dtype=torch.float32)
+        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
+        assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+        pos_dim = self.embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1. / (temperature**omega)
+        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
+        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
+        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
+
+        pos_embed = nn.Parameter(pos_emb)
+        pos_embed.requires_grad = False
+        return pos_embed
+
+    @property
+    def width(self):
+        return self.num_features
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_pos_embed(self, B, H, W):
+        if self.training:
+            return self.pos_embed
+        pos_embed = self.pos_embed
+        pos_embed = interpolate_pos_encoding(pos_embed, H, W)
+        return pos_embed
+
+    def forward_features(self, x, *, return_attn=False):
+        B = x.shape[0]
+        x, hw_shape = self.patch_embed(x)
+
+        x = x + self.get_pos_embed(B, *hw_shape)
+        x = self.pos_drop(x)
+
+        group_token = None
+        attn_dict_list = []
+        for i, layer in enumerate(self.layers):
+            x, group_token, attn_dict = layer(x, group_token, return_attn=return_attn)
+            if attn_dict is not None:
+                attn_dict_list.append(attn_dict)
+
+        x = self.norm(x)
+        return x, group_token, attn_dict_list
+
+    def forward_image_head(self, x):
+        """
+
+        Args:
+            x: shape [B, L, C]
+
+        Returns:
+
+        """
+        # [B, L, C]
+        x = self.avgpool(x.transpose(1, 2))  # B C 1
+        x = torch.flatten(x, 1)
+        x = self.head(x) if self.proj is None else x @ self.proj
+
+        return x
+
+    def forward(self, x, *, return_feat=False, return_attn=False, as_dict=False, sampled_noun_indices=None):
+        x, group_token, attn_dicts = self.forward_features(x, return_attn=return_attn)
+        x_feat = x if return_feat else None
+
+        outs = Result(as_dict=as_dict)
+        outs.append(self.forward_image_head(x), name='x')
+        if return_feat:
+            outs.append(x_feat if self.proj is None else x_feat @ self.proj, name='feat')
+
+        if return_attn:
+            outs.append(attn_dicts, name='attn_dicts')
+
+        return outs.as_return()
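+
+# Key-remapping sketch (illustrative helper, not called anywhere): with num_layers == 2,
+# a 12-block ViT checkpoint is split into two GroupingLayers of 6 blocks each, e.g.
+#   'blocks.3.attn.qkv.weight' -> 'layers.0.blocks.3.attn.qkv.weight'
+#   'blocks.7.attn.qkv.weight' -> 'layers.1.blocks.1.attn.qkv.weight'
+# mirroring what init_backbone_with_imagenet_weights does above for the non-CLIP, 2-stage case.
+def _remap_vit_block_key(key, blocks_per_stage=6):
+    parts = key.split('.')
+    if parts[0] != 'blocks':
+        return key
+    stage, local_id = divmod(int(parts[1]), blocks_per_stage)
+    return 'layers.{}.blocks.{}.{}'.format(stage, local_id, '.'.join(parts[2:]))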

+ 175 - 0
models/losses.py

@@ -0,0 +1,175 @@
+# -------------------------------------------------------------------------
+# Written by Jilan Xu
+# -------------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+import numpy as np
+from torch import linalg as LA
+
+from scipy.optimize import linear_sum_assignment
+# from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+from ipdb import set_trace
+import torch.distributed as dist
+import diffdist.functional as diff_dist
+
+
+def dist_collect(x):
+    """ collect all tensor from all GPUs
+    args:
+        x: shape (mini_batch, ...)
+    returns:
+        shape (mini_batch * num_gpu, ...)
+    """
+    x = x.contiguous()
+    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())]
+    out_list = diff_dist.all_gather(out_list, x)
+    return torch.cat(out_list, dim=0).contiguous()
+
+class HungarianMatcher(nn.Module):
+    """This class computes an assignment between the targets and the predictions of the network
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+    while the others are un-matched (and thus treated as non-objects).
+    """
+
+    def __init__(self, cost_type='L2'):
+        """Creates the matcher
+        Params:
+            cost_class: This is the relative weight of the classification error in the matching cost
+            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
+            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
+        """
+        super().__init__()
+        self.cost_type = cost_type
+        
+    @torch.no_grad()
+    def forward(self, outputs, targets):
+        """ Performs the matching
+        NewParams:
+            outputs: [b, k, h * w], k normalized masks 
+            targets: [b, k, h * w]  k normalized masks
+            
+        Params:s
+            outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+        Returns:
+            A list of size batch_size, containing tuples of (index_i, index_j) where:
+                - index_i is the indices of the selected predictions (in order)
+                - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds:
+                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        bs, num_queries = outputs.shape[:2]
+        # We flatten to compute the cost matrices in a batch
+        # out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
+        # out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
+        if self.cost_type == 'L2':
+            cost_mask = torch.cdist(outputs, targets, p=2) #[b, k, k]
+        elif self.cost_type == 'cosine':
+            ##### <a, b> / (||a|| * ||b||) ######
+            cos_sim = outputs @ targets.transpose(-2, -1)   #[b, k, k]
+            dist_a = LA.norm(outputs, dim=-1).unsqueeze(-1) #[b, k, 1]
+            dist_b = LA.norm(targets, dim=-1).unsqueeze(-2) #[b, 1, k]
+            eps = 1e-6
+            ### negative cosine similarity as cost matrix
+            cost_mask = -1 * (cos_sim / (dist_a + eps) / (dist_b + eps)) 
+        else:
+            raise ValueError(f'Unknown cost type: {self.cost_type}')
+        # set_trace()
+        inds = []
+        inds2 = []
+        for i in range(bs):
+            xx, yy = linear_sum_assignment(cost_mask[i].cpu())
+            inds.append(xx)
+            inds2.append(yy)
+        # indices = [linear_sum_assignment(cost_mask[i]) for i in range(bs)]
+        # indices = [linear_sum_assignment(c[i].cpu()) for i, c in enumerate(cost_mask.split(bs, -1))]
+        # indices = [linear_sum_assignment(c[i].cpu()) for i, c in zip(range(bs), cost_mask)]
+        inds = torch.tensor(inds).long().cuda()
+        inds2 = torch.tensor(inds2).long().cuda()
+        return inds, inds2
+        # indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+        # return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+    
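+
+# Toy sketch (illustrative, not used in training) of the assignment step inside
+# HungarianMatcher: given a [k, k] cost matrix between predicted and target masks,
+# linear_sum_assignment returns the 1-to-1 matching with minimum total cost.
+def _toy_hungarian_matching_example():
+    cost = np.array([[0.1, 0.9, 0.8],
+                     [0.7, 0.2, 0.9],
+                     [0.8, 0.6, 0.3]])
+    row_ind, col_ind = linear_sum_assignment(cost)
+    # row_ind == [0, 1, 2], col_ind == [0, 1, 2]; total cost = 0.1 + 0.2 + 0.3 = 0.6
+    return row_ind, col_ind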
+def dice_loss(inputs, targets, num_masks=None, threshold=0.0, topk_mask=None):
+    """
+    Compute the DICE loss, similar to generalized IOU for masks
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+        
+        1. norm the input and the target to [0, 1] with sigmoid
+        2. binarize the target
+        3. compute dice loss
+    """
+    if num_masks is None:
+        num_masks = inputs.size(1)
+
+    if topk_mask is not None:
+        ### [bs, k, nm] * [bs, k, 1], filter the masked clusters
+        inputs = inputs * topk_mask.unsqueeze(-1)
+        targets = targets * topk_mask.unsqueeze(-1) 
+
+    inputs = inputs.flatten(1)
+    targets = targets.flatten(1)
+
+    numerator = 2 * (inputs * targets).sum(-1)
+    denominator = inputs.sum(-1) + targets.sum(-1)
+    loss = 1 - (numerator + 1) / (denominator + 1)
+    return loss.sum() / num_masks
+
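+
+# Sanity-check sketch for dice_loss (illustrative): for a perfectly matching binary mask
+# the loss is 1 - (2*|A∩B| + 1) / (|A| + |B| + 1) = 1 - 5/5 = 0.
+def _toy_dice_loss_example():
+    pred = torch.tensor([[[1., 1., 0., 0.]]])     # [B=1, K=1, HW=4]
+    target = torch.tensor([[[1., 1., 0., 0.]]])
+    return dice_loss(pred, target)                # tensor(0.)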
+def get_logits(dense_feat_1, selected_feat_2, logit_scale):
+    # logit_scale_dense = self.logit_scale.exp()
+    logit_scale_dense = torch.clamp(logit_scale.exp(), max=100)
+    
+    i, j, k = dense_feat_1.shape
+    l, m, k = selected_feat_2.shape
+    dense_feat_1 = dense_feat_1.reshape(-1, k)
+    selected_feat_2 = selected_feat_2.reshape(-1, k)
+    final_logits_1 = logit_scale_dense * dense_feat_1 @ selected_feat_2.t()
+    final_logits_1 = final_logits_1.reshape(i, j, l, m).permute(0,2,1,3)
+    return final_logits_1
+
+
+def sim_matrix(a, b, eps=1e-8):
+    """
+    added eps for numerical stability
+    """
+
+    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+    return sim_mt
+
+class NormSoftmaxLoss(nn.Module):
+
+    def __init__(self, temperature=0.05):
+        super().__init__()
+
+        self.temperature = temperature
+
+    def forward(self, x):
+        i_logsm = F.log_softmax(x/self.temperature, dim=1)
+        j_logsm = F.log_softmax(x.t()/self.temperature, dim=1)
+
+        # sum over positives
+        idiag = torch.diag(i_logsm)
+        loss_i = idiag.sum() / len(idiag)
+
+        jdiag = torch.diag(j_logsm)
+        loss_j = jdiag.sum() / len(jdiag)
+
+        return - loss_i - loss_j
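+
+# Minimal usage sketch (illustrative): NormSoftmaxLoss expects a [B, B] similarity matrix
+# whose diagonal holds the positive image-text pairs; sim_matrix above produces exactly that.
+def _toy_norm_softmax_example():
+    image_feat = torch.randn(4, 16)
+    text_feat = torch.randn(4, 16)
+    sims = sim_matrix(image_feat, text_feat)      # [4, 4]
+    return NormSoftmaxLoss(temperature=0.05)(sims)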

+ 80 - 0
models/misc.py

@@ -0,0 +1,80 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+import math
+
+import torch.nn.functional as F
+from ipdb import set_trace
+
+class Result:
+
+    def __init__(self, as_dict=False):
+        if as_dict:
+            self.outs = {}
+        else:
+            self.outs = []
+
+    @property
+    def as_dict(self):
+        return isinstance(self.outs, dict)
+
+    def append(self, element, name=None):
+        if self.as_dict:
+            assert name is not None
+            self.outs[name] = element
+        else:
+            self.outs.append(element)
+
+    def update(self, **kwargs):
+        if self.as_dict:
+            self.outs.update(**kwargs)
+        else:
+            for v in kwargs.values():
+                self.outs.append(v)
+
+    def as_output(self):
+        if self.as_dict:
+            return self.outs
+        else:
+            return tuple(self.outs)
+
+    def as_return(self):
+        outs = self.as_output()
+        if self.as_dict:
+            return outs
+        if len(outs) == 1:
+            return outs[0]
+        return outs
+
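+
+# Small sketch of how the Result helper is used throughout the models (illustrative):
+# with as_dict=True outputs come back as a dict; otherwise as_return() gives a tuple,
+# or the single element itself when only one output was appended.
+def _toy_result_example():
+    res = Result(as_dict=True)
+    res.append(1, name='x')
+    res.update(y=2)
+    return res.as_return()    # {'x': 1, 'y': 2}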
+def interpolate_pos_encoding(pos_embed, H, W):
+    num_patches = H * W
+
+    # NOTE: if the positional embedding still carries a cls token, N below equals
+    # num_patches + 1 and the early-return shortcut would never match.
+    if pos_embed.ndim == 2:
+        pos_embed = pos_embed.unsqueeze(0)
+    N = pos_embed.shape[1]
+    if num_patches == N and W == H:
+        return pos_embed
+
+    patch_pos_embed = pos_embed
+    dim = pos_embed.shape[-1]
+    patch_pos_embed = F.interpolate(
+        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+        size=(H, W),
+        mode='bicubic',
+        align_corners=False)
+    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+    return patch_pos_embed
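+
+# Illustrative sketch: resizing a 196-token (14x14) positional embedding for a 16x16 patch
+# grid, as happens at evaluation time when the test resolution differs from pretraining.
+def _toy_interpolate_example():
+    import torch  # torch itself is not imported at the top of this module
+    pos = torch.randn(1, 196, 384)                        # [1, N, C] with N = 14 * 14
+    return interpolate_pos_encoding(pos, 16, 16).shape    # torch.Size([1, 256, 384])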

+ 811 - 0
models/multi_label_contrastive.py

@@ -0,0 +1,811 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+import diffdist.functional as diff_dist
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from timm.loss import SoftTargetCrossEntropy
+from random import choice
+
+from .builder import MODELS
+from .misc import Result
+from .losses import HungarianMatcher, dice_loss
+
+from ipdb import set_trace
+import torchvision.ops.roi_pool as roi_pool
+import cv2
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from .group_vit import CrossAttnBlock, AssignAttention, AttnBlock
+
+def dist_collect(x):
+    """ collect all tensor from all GPUs
+    args:
+        x: shape (mini_batch, ...)
+    returns:
+        shape (mini_batch * num_gpu, ...)
+    """
+    x = x.contiguous()
+    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())]
+    out_list = diff_dist.all_gather(out_list, x)
+    return torch.cat(out_list, dim=0).contiguous()
+
+class ProjectMLP(nn.Module):
+    def __init__(self, in_dim=256, inner_dim=4096, out_dim=256, num_layers=2):
+        super(ProjectMLP, self).__init__()
+        # hidden layers
+        linear_hidden = []
+        for i in range(num_layers - 1):
+            linear_hidden.append(nn.Conv1d(in_dim if i == 0 else inner_dim, inner_dim, kernel_size=1))
+            linear_hidden.append(nn.BatchNorm1d(inner_dim))
+            linear_hidden.append(nn.ReLU(inplace=True))
+        self.linear_hidden = nn.Sequential(*linear_hidden)
+
+        self.linear_out = nn.Conv1d(
+            in_dim if num_layers == 1 else inner_dim, out_dim, kernel_size=1) if num_layers >= 1 else nn.Identity()
+
+    def forward(self, x):
+        """
+
+        Args:
+            x (torch.Tensor): output of transformers, shape [B, L, C]
+
+        Returns:
+
+        """
+        assert x.ndim in [2, 3], x.ndim
+        add_dim = False
+        if x.ndim == 2:
+            # [B, C] -> [B, L, C]
+            x = x.unsqueeze(1)
+            add_dim = True
+
+        x = rearrange(x, 'b l c -> b c l')
+        x = self.linear_hidden(x)
+        x = self.linear_out(x)
+        x = rearrange(x, 'b c l -> b l c')
+
+        if add_dim:
+            x = x.squeeze(1)
+
+        return x
+
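+
+# Shape sketch for ProjectMLP (illustrative): both [B, C] and [B, L, C] inputs are accepted;
+# a 2-D input is temporarily lifted to [B, 1, C] so the Conv1d/BatchNorm1d stack runs per token.
+def _toy_project_mlp_example():
+    proj = ProjectMLP(in_dim=256, inner_dim=512, out_dim=128, num_layers=2)
+    tokens = torch.randn(2, 8, 256)    # [B, L, C]
+    pooled = torch.randn(2, 256)       # [B, C]
+    return proj(tokens).shape, proj(pooled).shape   # ([2, 8, 128], [2, 128])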
+class MultimodalGroupingBlock(nn.Module):
+    """Grouping Block to group similar segments together.
+
+    Args:
+        dim (int): Dimension of the input.
+        out_dim (int): Dimension of the output.
+        num_heads (int): Number of heads in the grouping attention.
+        num_output_group (int): Number of output groups.
+        norm_layer (nn.Module): Normalization layer to use.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+        hard (bool): Whether to use hard or soft assignment. Default: True
+        gumbel (bool): Whether to use gumbel softmax. Default: True
+        sum_assign (bool): Whether to sum assignment or average. Default: False
+        assign_eps (float): Epsilon to avoid divide by zero. Default: 1
+        gumbel_tau (float): Temperature for gumbel softmax. Default: 1
+    """
+
+    def __init__(self,
+                 *,
+                 dim,
+                 out_dim,
+                 num_heads,
+                 norm_layer,
+                 mlp_ratio=(0.5, 4.0),
+                 hard=True,
+                 gumbel=True,
+                 sum_assign=False,
+                 assign_eps=1.,
+                 gumbel_tau=1.,
+                 attn_drop=0.,
+                 ):
+        super(MultimodalGroupingBlock, self).__init__()
+        self.dim = dim
+        self.hard = hard
+        self.gumbel = gumbel
+        self.sum_assign = sum_assign
+        # norm on group_tokens
+        self.norm_tokens = norm_layer(dim)
+        tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
+        # norm on x
+        self.norm_x = norm_layer(dim)
+        # self.visual_attn = AttnBlock(
+        #     dim=dim, num_heads=num_heads, mlp_ratio=4, qkv_bias=True, norm_layer=norm_layer )
+        self.pre_assign_attn = CrossAttnBlock(
+            dim=dim, num_heads=num_heads, mlp_ratio=4, qkv_bias=True, norm_layer=norm_layer, post_norm=True)
+
+        self.post_attn = AttnBlock(
+            dim=dim, num_heads=num_heads, mlp_ratio=4, qkv_bias=True, norm_layer=norm_layer )
+        
+        self.assign = AssignAttention(
+            dim=dim,
+            num_heads=1,
+            qkv_bias=True,
+            hard=hard,
+            gumbel=gumbel,
+            gumbel_tau=gumbel_tau,
+            sum_assign=sum_assign,
+            assign_eps=assign_eps,
+            attn_drop=attn_drop,
+            )
+        self.norm_new_x = norm_layer(dim)
+
+    def forward(self, ans_tokens, visual_tokens, text_tokens, entity_masks=None, question_masks=None, return_attn=False):
+        """
+        Args:
+            x (torch.Tensor): group_tokens, [B, k, C]
+            group_tokens (torch.Tensor): word tokens, [B, L, C]
+            return_attn (bool): whether to return attention map
+
+        Returns:
+            new_x (torch.Tensor): [B, S_2, C], S_2 is the new number of
+                group tokens
+        """
+        # [B, K, C], self-attention
+        # visual_tokens = self.visual_attn(visual_tokens)
+    
+        text_tokens = self.norm_tokens(text_tokens)
+        visual_tokens = self.norm_x(visual_tokens)
+    
+        # [B, L, C], cross attention
+        projected_text_tokens = self.pre_assign_attn(text_tokens, visual_tokens)
+        ### mask needs to be [b, 1, 77, 1] to match [b, nh, 77, k]
+        # projected_text_tokens = text_tokens
+        # new_x, attn_dict = self.assign(projected_text_tokens, visual_tokens, return_attn=return_attn, mask=question_masks)
+        
+        if ans_tokens is None:
+            ans_temp = projected_text_tokens
+        else:
+            ans_temp = ans_tokens + projected_text_tokens    
+    
+        ############## self-attn only ###################
+        if question_masks is not None:
+            new_x = self.post_attn(ans_temp, mask=question_masks)
+        else:
+            new_x = self.post_attn(ans_temp)
+        new_x += projected_text_tokens
+        
+        new_x = self.norm_new_x(new_x)
+        return new_x
+
+
+class MultimodalGroupingNetwork(nn.Module):
+    """Grouping Block to group similar segments together.
+
+    Args:
+        dim (int): Dimension of the input.
+        out_dim (int): Dimension of the output.
+        num_heads (int): Number of heads in the grouping attention.
+        norm_layer (nn.Module): Normalization layer to use.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+        hard (bool): Whether to use hard or soft assignment. Default: True
+        gumbel (bool): Whether to use gumbel softmax. Default: True
+        sum_assign (bool): Whether to sum assignment or average. Default: False
+        assign_eps (float): Epsilon to avoid divide by zero. Default: 1
+        gumbel_tau (float): Temperature for gumbel softmax. Default: 1
+    """
+
+    def __init__(self,
+                 *,
+                 dim,
+                 out_dim,
+                 num_heads,
+                 norm_layer,
+                 mlp_ratio=(0.5, 4.0),
+                 hard=True,
+                 gumbel=True,
+                 sum_assign=False,
+                 assign_eps=1.,
+                 gumbel_tau=1.,
+                 attn_drop=0.,
+                 num_layers=1,
+                 ):
+        super(MultimodalGroupingNetwork, self).__init__()
+        self.num_layers = num_layers
+        self.blocks = nn.ModuleList([
+                MultimodalGroupingBlock(
+                    dim=dim,
+                    out_dim=out_dim,
+                    num_heads=num_heads,
+                    norm_layer=norm_layer,
+                    mlp_ratio=mlp_ratio,
+                    hard=hard,
+                    gumbel=gumbel,
+                    sum_assign=sum_assign,
+                    assign_eps=assign_eps,
+                    gumbel_tau=gumbel_tau,
+                    attn_drop=attn_drop,
+                ) for i in range(num_layers)
+            ])
+        
+        
+    def forward(self, visual_tokens, text_tokens, entity_masks=None, question_masks=None, return_attn=False, return_feat=False):
+        """
+        Args:
+            x (torch.Tensor): group_tokens, [B, k, C]
+            group_tokens (torch.Tensor): word tokens, [B, L, C]
+            return_attn (bool): whether to return attention map
+
+        Returns:
+            new_x (torch.Tensor): [B, S_2, C], S_2 is the new number of
+                group tokens
+
+            1. norm
+            2. cross-attn
+            3. self-attn
+
+        """
+        ans_text = None
+        for i, blk in enumerate(self.blocks):
+            ans_text = blk(ans_text, visual_tokens, text_tokens, entity_masks, question_masks, return_attn)
+        
+        if return_feat is True:  #[B, L, d_t]
+            return ans_text
+
+        answer = ans_text[:, 0]
+        return answer
+        
+
+@MODELS.register_module()
+class MultiLabelContrastive(nn.Module):
+
+    def __init__(self,
+                 img_encoder,
+                 text_encoder,
+                 output_dim=256,
+                 contrast_temperature=0.07,
+                 proj_num_layers=2,
+                 multi_label=0,
+                 share_temperature=False,
+                 multi_label_loss_weight=1.0,
+                 
+                 use_entityloss=False,
+                 entity_weight=1.0,
+                 cross_layers=1,
+
+                 use_maskloss=False,
+                 maskloss_weight=0.1,
+                 num_deep_stages=1,
+                 cost_type='L2',
+                 cross_threshold=0.6,
+                 topmask_ratio=1.0,
+                 dual_dice=False,
+                 group_ratio=0.5,
+                 ):
+        super().__init__()
+
+        self.img_encoder = MODELS.build(img_encoder)
+        self.text_encoder = MODELS.build(text_encoder)
+        self.img_encoder_type = img_encoder['type']
+        self.text_encoder_type = text_encoder['type']
+        # add 
+        print('self image encoder: ', img_encoder)
+        print('self text encoder:', text_encoder)
+
+        self.contrast_temperature = contrast_temperature
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
+        
+        self.cross_entropy = nn.CrossEntropyLoss(ignore_index=-1)
+        self.binary_cross_entropy = nn.BCELoss()
+        self.binary_cross_entropy_with_logits = nn.BCEWithLogitsLoss()
+        self.soft_cross_entropy = SoftTargetCrossEntropy()
+        self.mse_loss = nn.MSELoss()
+
+        
+        self.proj_num_layers = proj_num_layers
+        self.multi_label = multi_label
+        
+        if proj_num_layers > 0:
+        # if proj_num_layers > 0 and self.use_clip_visual is False:
+            self.img_projector = ProjectMLP(
+                in_dim=self.img_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
+            self.text_projector = ProjectMLP(
+                in_dim=self.text_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
+            self.img_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.img_projector)
+            self.text_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.text_projector)
+        elif proj_num_layers == -1:
+            self.img_projector = nn.Linear(self.img_encoder.width, self.text_encoder.width)
+            self.text_projector = nn.Identity()
+        else:
+            self.img_projector = nn.Identity()
+            self.text_projector = nn.Identity()
+        
+
+        self.share_temperature = share_temperature
+        if self.with_multi_label and not self.share_temperature:
+            self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
+        self.multi_label_loss_weight = multi_label_loss_weight
+        
+        ### for masked entity loss ###
+        self.use_entityloss = use_entityloss
+        self.entity_weight = entity_weight
+        self.cross_layers = cross_layers
+        if self.use_entityloss:
+            min_width = min(self.img_encoder.width, self.text_encoder.width)
+            max_width = max(self.img_encoder.width, self.text_encoder.width)
+            self.align_proj_img = nn.Linear(max_width, min_width) if self.img_encoder.width > self.text_encoder.width else nn.Identity()
+            self.align_proj_text = nn.Linear(max_width, min_width) if self.text_encoder.width > self.img_encoder.width else nn.Identity()
+            
+            ### similar to transformer decoder ###
+            self.multimodal_groupingblock = MultimodalGroupingNetwork(
+                dim=min_width,
+                out_dim=min_width,
+                num_heads=8,
+                norm_layer=nn.LayerNorm,
+                hard=False,
+                gumbel=False,
+                num_layers=cross_layers,
+            )
+            self.bridge_projector = ProjectMLP(
+                in_dim=min_width, num_layers=proj_num_layers, out_dim=output_dim)
+        
+        
+        ### for mask loss ###
+        self.use_maskloss = use_maskloss
+        self.maskloss_weight = maskloss_weight
+
+        self.cross_threshold = cross_threshold
+        self.topmask_ratio = topmask_ratio
+        self.dual_dice = dual_dice
+        self.group_ratio = group_ratio
+        
+        if self.use_maskloss:
+            self.num_deep_stages = num_deep_stages
+            self.logit_scale_mask = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
+            self.img_encoder_momentum = MODELS.build(img_encoder)
+            
+            self.q_projector = nn.Identity()
+            self.k_projector = nn.Identity()
+            self.q_projector_momentum = nn.Identity()
+            self.k_projector_momentum = nn.Identity()
+
+            ## set momentum branch offline
+            for p in self.img_encoder_momentum.parameters():
+                p.requires_grad = False
+            self.matcher = HungarianMatcher(cost_type=cost_type)
+            
+                    
+    def mask_loss(self, mask1, mask2, threshold, imgtokens=None, text=None, indicator='none'):
+        # set_trace()
+        bs = mask1.size(0)
+        num_masks = mask1.size(1)
+        
+        ################# hungarian matching #######################################
+        # [b, k, hw]; the k masks could optionally be made mutually exclusive with a softmax over k
+        # Note: the original masks are kept; only the L2-normalized copies are used to compute the matching cost
+        mask1 = torch.flatten(mask1, 2).float()
+        mask2 = torch.flatten(mask2, 2).float()
+        mask1_norm = F.normalize(mask1, dim=-1)
+        mask2_norm = F.normalize(mask2, dim=-1)
+
+        idx1, idx2 = self.matcher(mask1_norm, mask2_norm)
+        mask1 = mask1[torch.arange(bs).unsqueeze(1), idx1]
+        mask2 = mask2[torch.arange(bs).unsqueeze(1), idx2]
+        
+        ################## norm and contrastive loss ################################
+        #[b, k, hw]
+        
+        ################# BCE loss ##################################################
+        ### hard-thresholding ###
+        def min_max_norm(x):
+            x_max = torch.max(x, dim=-1, keepdim=True)[0]
+            x_min = torch.min(x, dim=-1, keepdim=True)[0]
+            return (x - x_min) / (x_max - x_min)
+        
+        ### normalize the target masks to [0, 1] before thresholding (sigmoid is used; alternatives kept below) ###
+        mask2 = mask2.sigmoid()
+        # mask2 = F.softmax(mask2, dim=1)
+        # mask2 = min_max_norm(mask2)
+        # mask2 = F.normalize(mask2)
+
+        mask2_pseudo = mask2
+        mask2_pseudo = rearrange(mask2_pseudo, 'b k d -> (b k) d')
+
+        thres_onehot = torch.max(mask2_pseudo, dim=-1, keepdim=True)[0] * threshold
+        mask2_onehot = mask2_pseudo - thres_onehot
+        mask2_onehot[mask2_onehot >= 0] = 1.0
+        mask2_onehot[mask2_onehot < 0] = 0.0
+        mask2_onehot = rearrange(mask2_onehot, '(b k) d -> b k d', k=num_masks)
+
+        # self.draw_attn(rearrange(mask1, 'b k (h w) -> b k h w', k=num_masks, h=224), 'before_sigmoid')
+        # set_trace()
+        # mask1 = F.softmax(mask1, dim=1)
+        # mask1 = torch.sigmoid(mask1)
+        mask1 = min_max_norm(mask1)
+        
+        ####### select topk mask for contrast w.r.t ratio #######
+        topk_mask = None
+        # if self.topmask_ratio < 1.0:
+        #     alltoken_logits = (imgtokens @ text.unsqueeze(-1)).squeeze(-1) #[bs, k]
+        #     topk_logits = torch.topk(alltoken_logits, k=int(num_masks * self.topmask_ratio))[1]
+        #     topk_mask = torch.zeros_like(alltoken_logits)
+        #     topk_mask[torch.arange(bs).unsqueeze(1), topk_logits] = 1.0
+            # set_trace()
+        #########################################################
+
+        loss = dice_loss(mask1, mask2_onehot, topk_mask=topk_mask) 
+
+        return loss
+
+    @property
+    def with_multi_label(self):
+        return self.multi_label > 0
+
+    def loss(self, image_x, text_x):
+        batch_size = image_x.shape[0]
+        # get label globally
+        labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + batch_size * dist.get_rank()
+
+        image_x = F.normalize(image_x, dim=-1) #[B, C]
+        text_x = F.normalize(text_x, dim=-1) #[B, C]
+
+        logits_per_img = image_x @ dist_collect(text_x).t()
+        logits_per_text = text_x @ dist_collect(image_x).t()
+        logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
+        loss_img = self.cross_entropy(logits_per_img * logit_scale, labels)
+        loss_text = self.cross_entropy(logits_per_text * logit_scale, labels)
+
+        loss = 0.5 * (loss_img + loss_text)
+        return loss
+
+    def multi_label_loss(self, image_feat, text_feat):
+        """
+
+        Args:
+            image_feat (torch.Tensor): shape [B, L1, C]
+            text_feat (torch.Tensor): shape [B, L2, C]
+
+        Returns:
+
+        """
+        # [B, L1, C], L1 = 1
+        image_feat = F.normalize(image_feat, dim=-1)
+        # [B, L2, C]
+        text_feat = F.normalize(text_feat, dim=-1)
+
+        # [B, L1, L2]
+        dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
+        # [B, L2, L1]
+        dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')
+
+        if self.share_temperature:
+            logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
+        else:
+            logit_scale = torch.clamp(self.multi_label_logit_scale.exp(), max=100)
+
+        batch = image_feat.shape[0]
+        img_len = image_feat.shape[1]
+        text_len = text_feat.shape[1]
+        # [B, L1, L2]
+        pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
+        # [B, L2, L1]
+        pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')
+
+        image_x = rearrange(image_feat, 'b l c -> (b l) c')
+        text_x = rearrange(text_feat, 'b l c -> (b l) c')
+
+        logits_per_img = image_x @ dist_collect(text_x).t()
+        logits_per_text = text_x @ dist_collect(image_x).t()
+
+        # get label globally
+        # [B, L1, B, L2, W]
+        labels_per_img = F.one_hot(
+            torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * dist.get_rank(),
+            num_classes=dist.get_world_size()).to(image_x.dtype)
+        labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
+            torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
+        # [BxL1, WxBxL2]
+        labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
+        # [B, L2, B, L1, W]
+        labels_per_text = F.one_hot(
+            torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * dist.get_rank(),
+            num_classes=dist.get_world_size()).to(text_x.dtype)
+        labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
+            torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
+        # [BxL2, WxBxL1]
+        labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')
+
+        loss_img = self.soft_cross_entropy(logits_per_img * logit_scale, labels_per_img)
+        loss_text = self.soft_cross_entropy(logits_per_text * logit_scale, labels_per_text)
+
+        loss = 0.5 * (loss_img + loss_text)
+
+        return loss
+
+    def encode_image(self, image, *, return_feat=False, as_dict=False, return_attn=False, momentum=False):
+        outs = Result(as_dict)
+        ### momentum branch, no gradient update ###
+        if momentum:
+            with torch.no_grad():
+                img_outs = self.img_encoder_momentum(image, return_feat=return_feat, as_dict=True, return_attn=return_attn)
+                outs.append(self.img_projector(img_outs['x']), 'image_x')
+                if return_feat and 'feat' in img_outs:
+                    outs.append(img_outs['x'], 'image_x_before_proj')
+                    outs.append(img_outs['feat'], 'image_feat_before_proj')
+                
+                if return_feat:
+                    outs.append(self.img_projector(img_outs['feat']), 'image_feat')
+                if return_attn:
+                    outs.append(img_outs['attn_dicts'], 'attn_dicts')
+                return outs.as_return()            
+        else:
+        ### online branch ###
+            img_outs = self.img_encoder(image, return_feat=return_feat, as_dict=True, return_attn=return_attn)
+            # change here
+            outs.append(self.img_projector(img_outs['x']), 'image_x')
+            if return_feat and 'feat' in img_outs:
+                outs.append(img_outs['x'], 'image_x_before_proj')
+                outs.append(img_outs['feat'], 'image_feat_before_proj')
+
+            if return_feat:
+                outs.append(self.img_projector(img_outs['feat']), 'image_feat')
+            if return_attn:
+                outs.append(img_outs['attn_dicts'], 'attn_dicts')
+            return outs.as_return()
+    
+    def encode_text(self, text, *, as_dict=False, forward_template=False):
+        # assert text.ndim in [2, 3], text.ndim
+        squeeze_dim = False
+        num_text = 1
+        if type(text) is not dict and text.ndim == 3:
+            num_text = text.shape[1]
+            text = rearrange(text, 'b n l -> (b n) l', n=num_text)
+            squeeze_dim = True
+
+        outs = Result(as_dict=as_dict)
+        # [B, C]
+        text_outs = self.text_encoder(text)
+        if 'all_tokens' in text_outs:
+            all_tokens = text_outs['all_tokens'].contiguous()
+
+        x = text_outs['x']
+        text_x = self.text_projector(x)
+        
+        outs.append(text_x, 'text_x')
+        outs.append(x, 'text_x_before_proj') # add transformer out
+        outs.append(all_tokens, 'text_feat_before_proj')
+        outs.append(self.text_projector(all_tokens), 'text_feat_after_proj')
+
+        # if squeeze_dim:
+        if (squeeze_dim or self.with_multi_label) and self.training:
+            # text_x = rearrange(text_x, '(b n) c -> b n c', n=num_text)
+            text_x = rearrange(text_x, '(b n) c -> b n c', n=self.multi_label + 1) ### 2 prompts and 1 caption
+            text_multi_label_x = text_x[:, 1:]
+            text_x = text_x[:, 0]
+            ####### here projection !!!! #######
+            outs.update(text_x=text_x, text_multi_label_x=text_multi_label_x)
+
+        return outs.as_return()
+ 
+
+    def project_and_mask(self, q, k, branch='online'):
+        scale = self.img_encoder.width ** -0.5
+
+        if branch == 'online':
+            q = self.q_projector(q)
+            k = self.k_projector(k)
+            attn = q @ k.transpose(-2, -1) * scale  ### no softmax for now
+        else:
+            with torch.no_grad():
+                q = self.q_projector_momentum(q)
+                k = self.k_projector_momentum(k)
+                attn = q @ k.transpose(-2, -1) * scale  ### no softmax for now
+
+        return attn
+    
+    def forward_train(self, image, text, cross_image=None, cross_entity=None, \
+                        question=None, answer=None, entity_masks=None, question_masks=None):
+        bs = image.size(0)
+        losses_dict = dict()
+        
+        ############################################################
+        ### Encode image and caption, calculate image-caption matching loss ###
+        text_outs = self.encode_text(text, as_dict=True)
+        text_x = text_outs['text_x']  # [B, C]
+        image_outs = self.encode_image(image, as_dict=True, return_feat=True, return_attn=True)
+        image_x = image_outs['image_x'] # [B, C]
+        
+        matchingloss = self.loss(image_x, text_x) 
+        losses_dict['matching'] = matchingloss
+        
+        
+        ############################################################
+        ### Encode question/answer and calculate masked entity modeling loss (if necessary) ###
+        entityloss = .0
+        if self.use_entityloss:
+            visual_feat = image_outs['image_feat_before_proj'] # unprojected group token features [b, k, d_v]
+            ### Encode questions ###
+            question_feat = self.encode_text(question, as_dict=True)['text_feat_before_proj']  ## unprojected word tokens, [B, L, d_t]
+            current_question_masks = question['attention_mask'] if isinstance(question, dict) else None
+            ### Encode answer ###
+            answer_feat = self.encode_text(answer, as_dict=True)['text_x']  # projected answer embedding, #[B, d]
+            ### project the group feature/question feature to the common multimodal space ###
+            visual_feat = self.align_proj_img(visual_feat)
+            question_feat = self.align_proj_text(question_feat)
+            ### calculate entity loss ### 
+            question_out = self.multimodal_groupingblock(visual_feat, question_feat, entity_masks=entity_masks, question_masks=current_question_masks) #[b, d_t]
+            question_out = self.bridge_projector(question_out) #[b, d]
+            entityloss = self.loss(question_out, answer_feat)            
+            
+            losses_dict['entity'] = entityloss
+        ############################################################
+        ### Encode cross-image and calculate mask loss ###
+        maskloss = .0
+        if self.use_maskloss:                
+            assert cross_image is not None and cross_entity is not None
+            
+            image_outs3 = self.encode_image(cross_image, as_dict=True, return_feat=True, return_attn=True, momentum=True)
+            # total_stages = len(image_outs3['attn_dicts'])
+            attn_q = image_outs['attn_dicts'][0]['q'].squeeze(1)
+            attn_k = image_outs['attn_dicts'][0]['k'].squeeze(1)
+            attn_q_cross = image_outs3['attn_dicts'][0]['q'].squeeze(1)
+            attn_k_cross = image_outs3['attn_dicts'][0]['k'].squeeze(1)
+
+            attn_map3 = self.project_and_mask(attn_q_cross, attn_k_cross)
+            attn_map_cross1 = self.project_and_mask(attn_q, attn_k_cross)   # the mask to match image
+            
+            
+            def compute_cross_loss(mask1, mask2, cross_entity, groups, indicator='none'):
+                mask1 = rearrange(mask1, 'b k (h w) -> b k h w', h=14, w=14)  # [b, k, h, w]; 14x14 hard-coded for 224px inputs with 16px patches
+                mask2 = rearrange(mask2, 'b k (h w) -> b k h w', h=14, w=14)  # [b, k, h, w]; 14x14 hard-coded for 224px inputs with 16px patches
+                mask1 = F.interpolate(mask1, size=(224, 224), mode='bilinear', align_corners=True)
+                mask2 = F.interpolate(mask2, size=(224, 224), mode='bilinear', align_corners=True)
+                
+                ###### get the representation of the sampled_noun and measure the similarity ###############
+                if cross_entity is not None:
+                    with torch.no_grad():
+                        noun_feat = self.encode_text(cross_entity, as_dict=True)['text_x']  # [bs, d_c]
+                        group_logits = (groups @ noun_feat.unsqueeze(-1)).squeeze(-1) #[bs, k]
+                        num_groups = group_logits.size(1)
+                        topk_logits = torch.topk(group_logits, k=int(num_groups*self.group_ratio), largest=False)[1]
+                        
+                    mask1[torch.arange(bs).unsqueeze(1), topk_logits] = mask1[torch.arange(bs).unsqueeze(1), topk_logits].detach()
+                ############################################################################################
+                return self.mask_loss(mask1, mask2.detach(), self.cross_threshold, indicator=indicator)
+            
+            maskloss_cross = compute_cross_loss(attn_map_cross1, attn_map3, cross_entity, image_outs['image_feat'], indicator='none')
+        
+            if self.dual_dice:
+                dual_image_outs = self.encode_image(image, as_dict=True, return_feat=True, return_attn=True, momentum=True)
+                dual_image_outs3 = self.encode_image(cross_image, as_dict=True, return_feat=True, return_attn=True)
+
+                dual_attn_q = dual_image_outs['attn_dicts'][0]['q'].squeeze(1)
+                dual_attn_k = dual_image_outs['attn_dicts'][0]['k'].squeeze(1)
+                dual_attn_q_cross = dual_image_outs3['attn_dicts'][0]['q'].squeeze(1)
+                dual_attn_k_cross = dual_image_outs3['attn_dicts'][0]['k'].squeeze(1)
+                
+                dual_attn_map = self.project_and_mask(dual_attn_q, dual_attn_k)
+                dual_attn_map_cross = self.project_and_mask(dual_attn_q_cross, dual_attn_k)
+                
+                dual_maskloss = compute_cross_loss(dual_attn_map_cross, dual_attn_map, cross_entity, dual_image_outs3['image_feat'], indicator='cross')
+                maskloss_cross = (maskloss_cross + dual_maskloss) * 0.5
+                
+            maskloss = maskloss_cross
+            losses_dict['mask'] = maskloss
+            
+        ############################################################
+        ### total loss ###
+        if self.use_entityloss and self.use_maskloss: ### for 2nd stage ###
+            losses = matchingloss + self.entity_weight * entityloss + self.maskloss_weight * maskloss
+        elif self.use_entityloss: ### for 1st stage ###
+            losses = matchingloss + self.entity_weight * entityloss
+        else: ### baseline ###
+            losses = matchingloss
+        
+        if self.with_multi_label:
+            image_multi_label_x = image_x.unsqueeze(1)
+            text_multi_label_x = text_outs['text_multi_label_x']
+            loss_multi_label = self.multi_label_loss(image_multi_label_x, text_multi_label_x) * self.multi_label_loss_weight
+            losses_dict['multi_label'] = loss_multi_label
+            losses += loss_multi_label
+            
+        losses_dict['loss'] = losses
+        return losses_dict
+
+    def forward_test(self, image, text):
+        return self.zero_shot_pred(image, text)
+
+    def forward(self, image, text, cross_image=None, cross_entity=None, \
+                 question=None, answer=None, entity_masks=None, question_masks=None):
+        """
+        
+        Args:
+            image: [b, 3, 224, 224] raw input image
+            text: [b, L] caption embedding after tokenisation with length L
+            cross_image: [b, 3, 224, 224] the image that shares the same entity with the input image
+            cross_entity: [b, L] text embedding of the shared entity after tokenisation 
+            question: [b, L] question embedding after tokenisation
+            answer: [b, L]  prompted answer embedding after tokenisation
+            entity_masks: [b, L] 
+            question_masks: [b, L]
+            
+        """
+        if self.training:
+            return self.forward_train(image=image, text=text, cross_image=cross_image, cross_entity=cross_entity, \
+                                    question=question, answer=answer, entity_masks=entity_masks, question_masks=question_masks)
+        else:
+            return self.forward_test(image, text)
+
+    @torch.no_grad()
+    def build_text_embedding(self, text, tokenizer=None, num_classes=20):
+        """
+
+        Args:
+            text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH] of CLIP-style tokenized prompts
+
+            For BERT-style encoders (DistilBert / Bert / BertMedium / Roberta):
+                text (list of str) of length num_classes * num_templates; num_classes is 20 for VOC
+                by default and 1000 for IN1K (num_classes is currently unused here)
+
+        Returns:
+            text_tokens (torch.Tensor): [NUM_CLASSES, C] L2-normalized class embeddings, averaged over templates
+        """
+        if self.text_encoder_type in ['DistilBert','Bert', 'BertMedium', 'Roberta']:
+            assert tokenizer is not None
+            text_data = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
+            text_data = {key: val.cuda() for key, val in text_data.items()}
+            text_tokens = self.encode_text(text_data, as_dict=True, forward_template=True)['text_x']
+        else:
+            text = text.to(next(self.parameters()).device)
+            num_classes, num_templates = text.shape[:2]
+            text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
+            text_tokens = self.encode_text(text, as_dict=True, forward_template=True)['text_x']
+        # [N, T, C]
+        # text_tokens = rearrange(text_tokens, '(n t) c -> n t c', n=num_classes, t=num_templates)
+        text_tokens = rearrange(text_tokens, '(n t) c -> n t c', n=num_classes)
+        # [N, C]
+        text_tokens = text_tokens.mean(dim=1)
+        text_tokens = F.normalize(text_tokens, dim=-1)
+
+        return text_tokens
+
+
+    @torch.no_grad()
+    def build_text_embedding_without_projection(self, text):
+        """
+
+        Args:
+            text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH]
+
+        Returns:
+
+        """
+        text = text.to(next(self.parameters()).device)
+        num_classes, num_templates = text.shape[:2]
+        text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
+        text_tokens = self.encode_text(text, as_dict=True, forward_template=True)['text_x_before_proj']
+        
+        # [N, T, C]
+        text_tokens = rearrange(text_tokens, '(n t) c -> n t c', n=num_classes, t=num_templates)
+        # [N, C]
+        text_tokens = text_tokens.mean(dim=1)
+        text_tokens = F.normalize(text_tokens, dim=-1)
+
+        return text_tokens
+
+    @torch.no_grad()
+    def zero_shot_pred(self, image, text):
+        # [B, C]
+        image_features = self.encode_image(image)
+        image_features = F.normalize(image_features, dim=-1)
+        # cosine similarity as logits
+        logits_per_image = image_features @ text.t()
+        return logits_per_image
+        
+
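+
+# Conceptual single-process sketch of the image-caption matching loss computed in
+# MultiLabelContrastive.loss (illustrative; the real implementation gathers features
+# across GPUs with dist_collect before the two cross-entropies).
+def _toy_matching_loss_example():
+    image_x = F.normalize(torch.randn(4, 256), dim=-1)
+    text_x = F.normalize(torch.randn(4, 256), dim=-1)
+    logit_scale = min(1.0 / 0.07, 100.0)
+    labels = torch.arange(4)
+    loss_img = F.cross_entropy(image_x @ text_x.t() * logit_scale, labels)
+    loss_text = F.cross_entropy(text_x @ image_x.t() * logit_scale, labels)
+    return 0.5 * (loss_img + loss_text)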

+ 319 - 0
models/transformer.py

@@ -0,0 +1,319 @@
+# -------------------------------------------------------------------------
+# MIT License
+#
+# Copyright (c) 2021 OpenAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+
+import torch
+import torch.utils.checkpoint as checkpoint
+from torch import nn
+
+from .builder import MODELS
+from .misc import Result
+from .utils import ResidualAttentionBlock
+from ipdb import set_trace
+import clip
+from transformers import AutoModel
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+class Transformer(nn.Module):
+
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_checkpoint=False):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+        proj_std = (self.width**-0.5) * ((2 * self.layers)**-0.5)
+        attn_std = self.width**-0.5
+        fc_std = (2 * self.width)**-0.5
+        for block in self.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        self.use_checkpoint = use_checkpoint
+
+    def forward(self, x: torch.Tensor):
+        for i, resblock in enumerate(self.resblocks):
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(resblock, x)
+            else:
+                x = resblock(x)
+        return x
+
+@MODELS.register_module()
+class DistilBert(nn.Module):
+    def __init__(
+        self,
+        context_length: int,
+        width: int,
+        layers: int,
+        vocab_size,
+        use_checkpoint=False,
+        pretrained=True,
+        fixed=True,
+    ):
+        super().__init__()
+        self.transformer = AutoModel.from_pretrained('distilbert-base-uncased', output_hidden_states=True)
+        self.transformer.train()
+        self.width = width
+    
+        if fixed is True:
+            for p in self.transformer.parameters():
+                p.requires_grad = False
+
+        if pretrained is False:
+            self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def forward(self, x, as_dict=True):
+        outs = Result(as_dict=as_dict)
+        out_x = self.transformer(**x)
+        out_hidden = out_x.last_hidden_state[:, 0, :]
+        last_hidden = out_x.hidden_states[-1]
+
+        outs.append(out_hidden, name='x')
+        outs.append(last_hidden, name='all_tokens')
+        return outs.as_return()
+
+@MODELS.register_module()
+class Bert(nn.Module):
+    def __init__(
+        self,
+        context_length: int,
+        width: int,
+        layers: int,
+        vocab_size,
+        use_checkpoint=False,
+        pretrained=True,
+        fixed=True,
+    ):
+        super().__init__()
+        self.transformer = AutoModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
+        self.transformer.train()
+        self.width = width
+    
+        if fixed is True:
+            for p in self.transformer.parameters():
+                p.requires_grad = False
+
+        if pretrained is False:
+            self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def forward(self, x, as_dict=True):
+        outs = Result(as_dict=as_dict)
+        out_x = self.transformer(**x)
+        out_hidden = out_x.last_hidden_state[:, 0, :]
+        last_hidden = out_x.hidden_states[-1]
+
+        outs.append(out_hidden, name='x')
+        outs.append(last_hidden, name='all_tokens')
+        return outs.as_return()
+    
+@MODELS.register_module()
+class Roberta(nn.Module):
+    def __init__(
+        self,
+        context_length: int,
+        width: int,
+        layers: int,
+        vocab_size,
+        use_checkpoint=False,
+        pretrained=True,
+        fixed=True,
+    ):
+        super().__init__()
+        self.transformer = AutoModel.from_pretrained('roberta-base', output_hidden_states=True, cache_dir='/mnt/petrelfs/xujilan/checkpoints/')
+        self.transformer.train()
+        self.width = width
+    
+        if fixed is True:
+            for p in self.transformer.parameters():
+                p.requires_grad = False
+
+        if pretrained is False:
+            self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def forward(self, x, question=None, as_dict=True):
+        outs = Result(as_dict=as_dict)
+        out_x = self.transformer(**x)
+        out_hidden = out_x.last_hidden_state[:, 0, :]
+        last_hidden = out_x.hidden_states[-1]
+
+        outs.append(out_hidden, name='x')
+        outs.append(last_hidden, name='all_tokens')
+        return outs.as_return()
+
+@MODELS.register_module()
+class BertMedium(nn.Module):
+    def __init__(
+        self,
+        context_length: int,
+        width: int,
+        layers: int,
+        vocab_size,
+        use_checkpoint=False,
+        pretrained=True,
+        fixed=True,
+    ):
+        super().__init__()
+        self.transformer = AutoModel.from_pretrained('prajjwal1/bert-medium', output_hidden_states=True)
+        self.transformer.train()
+        self.width = width
+    
+        if fixed is True:
+            for p in self.transformer.parameters():
+                p.requires_grad = False
+
+        if pretrained is False:
+            self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def forward(self, x, as_dict=True):
+        outs = Result(as_dict=as_dict)
+        out_x = self.transformer(**x)
+        out_hidden = out_x.last_hidden_state[:, 0, :]
+        last_hidden = out_x.hidden_states[-1]
+
+        outs.append(out_hidden, name='x')
+        outs.append(last_hidden, name='all_tokens')
+        return outs.as_return()
+
+@MODELS.register_module()
+class TextTransformer(nn.Module):
+
+    def __init__(
+        self,
+        context_length: int,
+        width: int,
+        layers: int,
+        vocab_size,
+        use_checkpoint=False,
+        pretrained=True,
+        fixed=True,
+    ):
+
+        super().__init__()
+        heads = width // 64
+        self.context_length = context_length
+        self.width = width
+        self.transformer = Transformer(
+            width=width,
+            layers=layers,
+            heads=heads,
+            attn_mask=self.build_attention_mask(),
+            use_checkpoint=use_checkpoint)
+
+        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
+        self.ln_final = nn.LayerNorm(width)
+        self.token_embedding = nn.Embedding(vocab_size, width)
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+
+        clip_model, _ = clip.load('ViT-B/16', device='cuda', jit=False)
+        self.text_projection = nn.Parameter(torch.empty(clip_model.text_projection.shape))
+        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+        # initialization
+        nn.init.normal_(self.positional_embedding, std=0.01)
+
+        if pretrained:
+            print('loading clip weights for text encoder')
+            self.reload_clip_weights(clip_model)
+        if fixed:
+            print('freezing text encoder')
+            self.freeze_text_encoder()
+
+    def freeze_text_encoder(self):
+        for p in self.parameters():
+            p.requires_grad=False
+
+    def reload_clip_weights(self, clip_model):
+        text_dict = clip_model.state_dict()
+        msg = self.load_state_dict(text_dict, strict=False)
+
+    def build_attention_mask(self):
+        # lazily create the causal attention mask over the text tokens
+        # pytorch uses an additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float('-inf'))
+        mask.triu_(1)  # keep -inf strictly above the diagonal, so tokens cannot attend to the future
+        return mask
+
+    def forward(self, text, *, as_dict=True):
+        x = self.token_embedding(text)
+        outs = Result(as_dict=as_dict)
+        x = x + self.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x)
+
+        ### w/o text projection ###
+        # all_tokens = x.clone()
+        # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
+
+        ### w/ text projection ###
+        all_tokens = x.clone() @ self.text_projection
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+        outs.append(x, name='x')
+        outs.append(all_tokens, name='all_tokens')
+
+        return outs.as_return()
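
Note: `TextTransformer.__init__` always calls `clip.load('ViT-B/16', device='cuda', jit=False)` to source the text projection (and, when `pretrained=True`, the full text tower), so constructing it requires the OpenAI `clip` package and a visible GPU. A construction sketch with hyper-parameters assumed from CLIP ViT-B/16, not taken from this diff:

import clip
import torch

text_encoder = TextTransformer(
    context_length=77,     # CLIP default context length
    width=512,             # CLIP ViT-B/16 text width
    layers=12,
    vocab_size=49408,
    pretrained=True,       # copy CLIP text-tower weights (strict=False)
    fixed=True,            # freeze all parameters
).cuda()

tokens = clip.tokenize(['a photo of a dog.']).cuda()   # [1, 77] int64 token ids
out = text_encoder(tokens, as_dict=True)
# 'x': EOT-token embedding after the text projection, [1, 512]
# 'all_tokens': projected per-token features, [1, 77, 512]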

+ 60 - 0
models/utils.py

@@ -0,0 +1,60 @@
+# -------------------------------------------------------------------------
+# MIT License
+#
+# Copyright (c) 2021 OpenAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from collections import OrderedDict
+
+import torch
+from torch import nn
+
+
+class QuickGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = nn.LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ('c_fc', nn.Linear(d_model, d_model * 4)), 
+            ('gelu', QuickGELU()),
+            ('c_proj', nn.Linear(d_model * 4, d_model))]))
+        self.ln_2 = nn.LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, key_padding_mask=key_padding_mask)[0]
+
+    def forward(self, x: torch.Tensor, key_padding_mask=None):
+        x = x + self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)
+        x = x + self.mlp(self.ln_2(x))
+        return x
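
Note: `nn.MultiheadAttention` is used with its default `batch_first=False`, so `ResidualAttentionBlock` expects sequence-first input `[L, N, D]` (the caller in `TextTransformer.forward` permutes `NLD -> LND` before the transformer). A shape-check sketch with assumed sizes:

import torch

block = ResidualAttentionBlock(d_model=512, n_head=8)        # attn_mask=None -> bidirectional
x = torch.randn(77, 4, 512)                                  # [seq_len, batch, dim]
pad = torch.zeros(4, 77, dtype=torch.bool)                   # [batch, seq_len], True = padded
y = block(x, key_padding_mask=pad)
assert y.shape == x.shape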

+ 314 - 0
models/vision_transformer.py

@@ -0,0 +1,314 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+"""
+Mostly copy-paste from DINO and timm library:
+https://github.com/facebookresearch/dino
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+"""
+
+import math
+import torch
+import torch.nn as nn
+
+from functools import partial
+from timm.models.registry import register_model
+# `_no_grad_trunc_normal_` is not defined in this file, so reuse timm's implementation
+from timm.models.layers import trunc_normal_
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    random_tensor.floor_()  # binarize
+    output = x.div(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, 
+                 dim, 
+                 num_heads=8, 
+                 qkv_bias=False, 
+                 qk_scale=None, 
+                 attn_drop=0., 
+                 proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, attn
+
+class Block(nn.Module):
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., 
+                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, init_values=0):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if init_values > 0:
+            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        else:
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x, return_attention=False):
+        y, attn = self.attn(self.norm1(x))
+        if return_attention:
+            return attn
+        if self.gamma_1 is None:
+            x = x + self.drop_path(y)
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.gamma_1 * y)
+            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        num_patches = (img_size // patch_size) * (img_size // patch_size)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+            
+    def forward(self, x):
+        B, C, H, W = x.shape
+        return self.proj(x)
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer """
+    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), return_all_tokens=False, 
+                 init_values=0, use_mean_pooling=False, masked_im_modeling=False):
+        super().__init__()
+        self.num_features = self.embed_dim = embed_dim
+        self.return_all_tokens = return_all_tokens
+
+        self.patch_embed = PatchEmbed(
+            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 
+                init_values=init_values)
+            for i in range(depth)])
+
+        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+        # Classifier head
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+        # masked image modeling
+        print('use masked image modeling:', masked_im_modeling)
+        self.masked_im_modeling = masked_im_modeling
+        if masked_im_modeling:
+            self.masked_embed = nn.Parameter(torch.zeros(1, embed_dim))
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def interpolate_pos_encoding(self, x, w, h):
+        npatch = x.shape[1] - 1
+        N = self.pos_embed.shape[1] - 1
+        if npatch == N and w == h:
+            return self.pos_embed
+        class_pos_embed = self.pos_embed[:, 0]
+        patch_pos_embed = self.pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_embed.patch_size
+        h0 = h // self.patch_embed.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + 0.1, h0 + 0.1
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode='bicubic',
+        )
+        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def prepare_tokens(self, x, mask=None):
+        B, nc, w, h = x.shape
+        # patch linear embedding
+        x = self.patch_embed(x)
+
+        # mask image modeling
+        if mask is not None:
+            x = self.mask_model(x, mask)
+        x = x.flatten(2).transpose(1, 2)
+
+        # add the [CLS] token to the embed patch tokens
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # add positional encoding to each token
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        return self.pos_drop(x)
+
+    def forward(self, x, return_all_tokens=None, mask=None):
+        # mim
+        if self.masked_im_modeling:
+            assert mask is not None
+            #print('whats up here: ' , x.shape, mask.shape)
+            x = self.prepare_tokens(x, mask=mask)
+        else:
+            x = self.prepare_tokens(x)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x = self.norm(x)
+        if self.fc_norm is not None:
+            x[:, 0] = self.fc_norm(x[:, 1:, :].mean(1))
+        
+        return_all_tokens = self.return_all_tokens if \
+            return_all_tokens is None else return_all_tokens
+        if return_all_tokens:
+            return x
+        return x[:, 0]
+
+    def get_last_selfattention(self, x):
+        x = self.prepare_tokens(x)
+        for i, blk in enumerate(self.blocks):
+            if i < len(self.blocks) - 1:
+                x = blk(x)
+            else:
+                # return attention of the last block
+                return blk(x, return_attention=True)
+
+    def get_intermediate_layers(self, x, n=1):
+        x = self.prepare_tokens(x)
+        # we return the output tokens from the `n` last blocks
+        output = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if len(self.blocks) - i <= n:
+                output.append(self.norm(x))
+        return output
+        
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    def mask_model(self, x, mask):
+        x.permute(0, 2, 3, 1)[mask, :] = self.masked_embed.to(x.dtype)
+        return x
+
+def vit_mini(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=384, depth=4, num_heads=3, mlp_ratio=4,
+        qkv_bias=True, **kwargs)
+    return model
+
+def vit_tiny(image_size=[224], patch_size=16, **kwargs):
+    # VisionTransformer expects `img_size`, so the factory argument is remapped here
+    model = VisionTransformer(
+        img_size=image_size, patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
+        qkv_bias=True, **kwargs)
+    return model
+
+def vit_small(image_size=[224], patch_size=16, **kwargs):
+    model = VisionTransformer(
+        img_size=image_size, patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
+        qkv_bias=True, **kwargs)
+    return model
+
+def vit_base(image_size=[224], patch_size=16, **kwargs):
+    model = VisionTransformer(
+        img_size=image_size, patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        qkv_bias=True, **kwargs)
+    return model
+
+def vit_large(image_size=[224], patch_size=16, **kwargs):
+    model = VisionTransformer(
+        img_size=image_size, patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, **kwargs)
+    return model

+ 24 - 0
requirements.txt

@@ -0,0 +1,24 @@
+braceexpand==0.1.7
+datadings==3.4.6
+diffdist==0.1
+einops==0.4.1
+ftfy==6.0.3
+ipdb==0.13.9
+ipython==8.11.0
+kestrel==0.0.1
+nltk==3.8.1
+omegaconf==2.1.0
+# opencv_python==4.6.0.66
+pandarallel==1.6.4
+regex==2022.3.15
+requests==2.28.1
+scipy==1.7.3
+spacy==3.5.0
+tensorboard==2.11.0
+termcolor==2.2.0
+timm==0.6.12
+tqdm==4.64.1
+transformers==4.21.0
+wandb==0.13.7
+webdataset==0.1.103
+# yfcc100m==1.0.1

BIN
segmentation/.DS_Store


BIN
segmentation/configs/.DS_Store


BIN
segmentation/configs/_base_/.DS_Store


+ 15 - 0
segmentation/configs/_base_/custom_import.py

@@ -0,0 +1,15 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+custom_imports = dict(
+    imports=['segmentation.datasets.coco_object', 'segmentation.datasets.pascal_voc'], allow_failed_imports=False)

+ 47 - 0
segmentation/configs/_base_/datasets/ade20k.py

@@ -0,0 +1,47 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+
+_base_ = ['../custom_import.py']
+# dataset settings
+dataset_type = 'ADE20KDataset'
+data_root = '/mnt/petrelfs/xujilan/data/ADEChallengeData2016/'
+
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 448),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        # split='ImageSets/Segmentation/val.txt',
+        pipeline=test_pipeline))
+
+test_cfg = dict(bg_thresh=.95, mode='slide', stride=(224, 224), crop_size=(448, 448))

+ 45 - 0
segmentation/configs/_base_/datasets/coco.py

@@ -0,0 +1,45 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+_base_ = ['../custom_import.py']
+# dataset settings
+dataset_type = 'COCOObjectDataset'
+data_root = '/mnt/petrelfs/xujilan/data/coco/'
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        # img_scale=(2048, 512),
+        img_scale=(2048, 448),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/val2017',
+        ann_dir='annotations/val2017_cocoobject',
+        pipeline=test_pipeline))
+
+test_cfg = dict(bg_thresh=.9, mode='slide', stride=(224, 224), crop_size=(448, 448))

+ 46 - 0
segmentation/configs/_base_/datasets/coco_stuff.py

@@ -0,0 +1,46 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+_base_ = ['../custom_import.py']
+# dataset settings
+dataset_type = 'COCOStufferDataset'
+data_root = '/mnt/petrelfs/xujilan/data/coco/'
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        # img_scale=(2048, 512),
+        img_scale=(2048, 448),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/val2017',
+        ann_dir='annotations/val2017',
+        pipeline=test_pipeline))
+
+# test_cfg = dict(bg_thresh=.95, mode='whole')
+test_cfg = dict(bg_thresh=.95, mode='slide', stride=(224, 224), crop_size=(448, 448))

+ 45 - 0
segmentation/configs/_base_/datasets/pascal_context.py

@@ -0,0 +1,45 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+_base_ = ['../custom_import.py']
+# dataset settings
+dataset_type = 'PascalContextDataset'
+data_root = '/mnt/petrelfs/xujilan/data/pascal_context/VOCdevkit/VOC2010/'
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 448),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClassContext',
+        split='ImageSets/SegmentationContext/val.txt',
+        pipeline=test_pipeline))
+
+test_cfg = dict(bg_thresh=.35, mode='slide', stride=(224, 224), crop_size=(448, 448))

+ 58 - 0
segmentation/configs/_base_/datasets/pascal_voc12 copy.py

@@ -0,0 +1,58 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+_base_ = ['../custom_import.py']
+# dataset settings
+dataset_type = 'PascalVOCDataset'
+data_root = '/mnt/cache/share_data/DSK_datasets/VOCdevkit/VOC2012'
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 448),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 448),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClass',
+        split='ImageSets/Segmentation/val.txt',
+        pipeline=test_pipeline))
+
+test_cfg = dict(bg_thresh=.95, mode='slide', stride=(224, 224), crop_size=(448, 448))

+ 46 - 0
segmentation/configs/_base_/datasets/pascal_voc12.py

@@ -0,0 +1,46 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+
+_base_ = ['../custom_import.py']
+# dataset settings
+dataset_type = 'PascalVOCDataset'
+data_root = '/mnt/petrelfs/xujilan/data/VOCdevkit/VOC2012'
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 448),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClass',
+        split='ImageSets/Segmentation/val.txt',
+        pipeline=test_pipeline))
+
+test_cfg = dict(bg_thresh=.95, mode='slide', stride=(224, 224), crop_size=(448, 448))

+ 23 - 0
segmentation/datasets/__init__.py

@@ -0,0 +1,23 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+
+from .coco_object import COCOObjectDataset
+from .pascal_context import PascalContextDataset
+from .pascal_voc import PascalVOCDataset
+from .coco_stuff import COCOStufferDataset
+from .ade20k import ADE20KDataset
+
+__all__ = ['COCOObjectDataset', 'PascalContextDataset', 'PascalVOCDataset', 'COCOStufferDataset', 'ADE20KDataset']

+ 48 - 0
segmentation/datasets/ade20k.py

@@ -0,0 +1,48 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+
+from mmseg.datasets import DATASETS
+from mmseg.datasets import ADE20KDataset as _ADE20KDataset
+
+
+@DATASETS.register_module(force=True)
+class ADE20KDataset(_ADE20KDataset):
+    CLASSES = (
+        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+        'clock', 'flag')
+    

+ 51 - 0
segmentation/datasets/coco_object.py

@@ -0,0 +1,51 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+
+from mmseg.datasets import DATASETS, CustomDataset
+
+
+@DATASETS.register_module()
+class COCOObjectDataset(CustomDataset):
+    """COCO-Object dataset.
+
+    1 bg class + first 80 classes from the COCO-Stuff dataset.
+    """
+
+    CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
+               'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
+               'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
+               'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
+               'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
+               'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
+               'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
+               'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
+               'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
+
+    PALETTE = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224],
+               [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64],
+               [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
+               [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0],
+               [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32],
+               [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
+               [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32],
+               [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
+               [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
+               [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160],
+               [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0],
+               [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0]]
+
+    def __init__(self, **kwargs):
+        super(COCOObjectDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs)

+ 58 - 0
segmentation/datasets/coco_stuff.py

@@ -0,0 +1,58 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+from mmseg.datasets import DATASETS, CustomDataset
+
+
+@DATASETS.register_module()
+class COCOStufferDataset(CustomDataset):
+    """COCO-Stuff dataset.
+
+    1 bg class + 80 things + 91 stuff classes from the COCO-Stuff dataset.
+    """
+
+    CLASSES = ('background','person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 
+               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
+               'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 
+               'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 
+               'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 
+               'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+               'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 
+               'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', 
+               'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', 
+               'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 
+               'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 
+               'net', 'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform', 'playingfield', 'railing', 'railroad', 'river', 'road', 
+               'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs', 'stone', 'straw',
+               'structural-other', 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel', 
+               'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', 'window-blind', 'window-other', 'wood')
+
+    PALETTE = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224],
+               [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64],
+               [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
+               [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0],
+               [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32],
+               [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
+               [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32],
+               [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
+               [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
+               [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160],
+               [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0],
+               [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0],
+               [1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 0, 3], [1, 0, 4], [1, 0, 5], [1, 0, 6], [1, 0, 7], [1, 0, 8], [1, 0, 9], [1, 0, 10], [1, 0, 11], [1, 0, 12], [1, 0, 13], [1, 0, 14], [1, 0, 15], [1, 0, 16], [1, 0, 17], [1, 0, 18], [1, 0, 19], [1, 0, 20], [1, 0, 21], [1, 0, 22], [1, 0, 23], [1, 0, 24], [1, 0, 25], [1, 0, 26], [1, 0, 27], [1, 0, 28], [1, 0, 29], [1, 0, 30], [1, 0, 31], [1, 0, 32], [1, 0, 33], [1, 0, 34], [1, 0, 35], [1, 0, 36], [1, 0, 37], [1, 0, 38], [1, 0, 39], [1, 0, 40], [1, 0, 41], [1, 0, 42], [1, 0, 43], [1, 0, 44], [1, 0, 45], [1, 0, 46], [1, 0, 47], [1, 0, 48], [1, 0, 49], [1, 0, 50], [1, 0, 51], [1, 0, 52], [1, 0, 53], [1, 0, 54], [1, 0, 55], [1, 0, 56], [1, 0, 57], [1, 0, 58], [1, 0, 59], [1, 0, 60], [1, 0, 61], [1, 0, 62], [1, 0, 63], [1, 0, 64], [1, 0, 65], [1, 0, 66], [1, 0, 67], [1, 0, 68], [1, 0, 69], [1, 0, 70], [1, 0, 71], [1, 0, 72], [1, 0, 73], [1, 0, 74], [1, 0, 75], [1, 0, 76], [1, 0, 77], [1, 0, 78], [1, 0, 79], [1, 0, 80], [1, 0, 81], [1, 0, 82], [1, 0, 83], [1, 0, 84], [1, 0, 85], [1, 0, 86], [1, 0, 87], [1, 0, 88], [1, 0, 89], [1, 0, 90],
+               ]
+
+    def __init__(self, **kwargs):
+        super(COCOStufferDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs)

+ 26 - 0
segmentation/datasets/pascal_context.py

@@ -0,0 +1,26 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from mmseg.datasets import DATASETS
+from mmseg.datasets import PascalContextDataset as _PascalContextDataset
+
+
+@DATASETS.register_module(force=True)
+class PascalContextDataset(_PascalContextDataset):
+
+    CLASSES = ('background', 'airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', 'boat', 'book',
+               'bottle', 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
+               'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
+               'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform', 'plant', 'road',
+               'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train', 'tree',
+               'truck', 'monitor', 'wall', 'water', 'window', 'wood')

+ 25 - 0
segmentation/datasets/pascal_voc.py

@@ -0,0 +1,25 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from mmseg.datasets import DATASETS
+from mmseg.datasets import PascalVOCDataset as _PascalVOCDataset
+
+
+@DATASETS.register_module(force=True)
+class PascalVOCDataset(_PascalVOCDataset):
+
+    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
+              'table', 'dog', 'horse', 'motorbike', 'person', 'plant', 'sheep', 'sofa', 'train', 'monitor')
+    # CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
+    #            'table', 'dog', 'horse', 'motorbike', 'person', 'plant', 'sheep', 'sofa', 'train', 'monitor')
+    

+ 22 - 0
segmentation/evaluation/__init__.py

@@ -0,0 +1,22 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+from .builder import build_seg_dataloader, build_seg_dataset, build_seg_demo_pipeline, build_seg_inference
+from .group_vit_seg import GROUP_PALETTE, GroupViTSegInference
+
+__all__ = [
+    'GroupViTSegInference', 'build_seg_dataset', 'build_seg_dataloader', 'build_seg_inference',
+    'build_seg_demo_pipeline', 'GROUP_PALETTE'
+]

+ 120 - 0
segmentation/evaluation/builder.py

@@ -0,0 +1,120 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+import mmcv
+from mmseg.datasets import build_dataloader, build_dataset
+from mmseg.datasets.pipelines import Compose
+from omegaconf import OmegaConf
+from utils import build_dataset_class_tokens, build_dataset_class_lists
+
+from .group_vit_seg import GroupViTSegInference
+from ipdb import set_trace
+
+
+def build_seg_dataset(config):
+    """Build a dataset from config."""
+    cfg = mmcv.Config.fromfile(config.cfg)
+    dataset = build_dataset(cfg.data.test)
+    return dataset
+
+
+def build_seg_dataloader(dataset):
+
+    data_loader = build_dataloader(
+        dataset,
+        samples_per_gpu=1,
+        workers_per_gpu=1,
+        dist=True,
+        shuffle=False,
+        persistent_workers=True,
+        pin_memory=False)
+    return data_loader
+
+
+def build_seg_inference(model, dataset, text_transform, config, tokenizer=None):
+    cfg = mmcv.Config.fromfile(config.cfg)
+    if len(config.opts):
+        cfg.merge_from_dict(OmegaConf.to_container(OmegaConf.from_dotlist(OmegaConf.to_container(config.opts))))
+    with_bg = dataset.CLASSES[0] == 'background'
+    if with_bg:
+        classnames = dataset.CLASSES[1:]
+    else:
+        classnames = dataset.CLASSES
+    
+    if tokenizer is not None:
+        text_tokens = build_dataset_class_lists(config.template, classnames)
+        text_embedding = model.build_text_embedding(text_tokens, tokenizer, num_classes=len(classnames))
+    else:
+        text_tokens = build_dataset_class_tokens(text_transform, config.template, classnames)
+        text_embedding = model.build_text_embedding(text_tokens, num_classes=len(classnames))
+    kwargs = dict(with_bg=with_bg)
+
+    if hasattr(cfg, 'test_cfg'):
+        kwargs['test_cfg'] = cfg.test_cfg
+    
+    seg_model = GroupViTSegInference(model, text_embedding, **kwargs)
+    print('Evaluate GroupViT during seg inference')
+
+    seg_model.CLASSES = dataset.CLASSES
+    seg_model.PALETTE = dataset.PALETTE
+
+    return seg_model
+
+
+class LoadImage:
+    """A simple pipeline to load image."""
+    cnt = 0
+    def __call__(self, results):
+        """Call function to load images into results.
+
+        Args:
+            results (dict): A result dict contains the file name
+                of the image to be read.
+
+        Returns:
+            dict: ``results`` will be returned containing loaded image.
+        """
+
+        if isinstance(results['img'], str):
+            results['filename'] = results['img']
+            results['ori_filename'] = results['img']
+        else:
+            results['filename'] = None
+            results['ori_filename'] = None
+        img = mmcv.imread(results['img'])
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['ori_shape'] = img.shape
+        return results
+
+
+def build_seg_demo_pipeline():
+    """Build a demo pipeline from config."""
+    img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+    test_pipeline = Compose([
+        LoadImage(),
+        dict(
+            type='MultiScaleFlipAug',
+            img_scale=(2048, 448),
+            flip=False,
+            transforms=[
+                dict(type='Resize', keep_ratio=True),
+                dict(type='RandomFlip'),
+                dict(type='Normalize', **img_norm_cfg),
+                dict(type='ImageToTensor', keys=['img']),
+                dict(type='Collect', keys=['img']),
+            ])
+    ])
+    return test_pipeline
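
A sketch (not part of the patch) of how these helpers are typically chained at evaluation time; `eval_seg_cfg` is assumed to be a config node exposing the `cfg`, `template` and `opts` fields read above, and `model` / `text_transform` come from the training code:

from mmcv.parallel import MMDataParallel
from mmseg.apis import single_gpu_test

dataset = build_seg_dataset(eval_seg_cfg)            # mmseg test split from eval_seg_cfg.cfg
data_loader = build_seg_dataloader(dataset)          # 1 image per GPU, no shuffling
seg_model = build_seg_inference(model, dataset, text_transform, eval_seg_cfg)
results = single_gpu_test(MMDataParallel(seg_model.cuda(), device_ids=[0]), data_loader)
print(dataset.evaluate(results, metric='mIoU'))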

+ 209 - 0
segmentation/evaluation/group_palette.txt

@@ -0,0 +1,209 @@
+128 64 128
+244 35 232
+70 70 70
+102 102 156
+190 153 153
+153 153 153
+250 170 30
+220 220 0
+107 142 35
+152 251 152
+70 130 180
+220 20 60
+255 0 0
+0 0 142
+0 0 70
+0 60 100
+0 80 100
+0 0 230
+119 11 32
+128 0 0
+0 128 0
+128 128 0
+0 0 128
+128 0 128
+0 128 128
+128 128 128
+64 0 0
+192 0 0
+64 128 0
+192 128 0
+64 0 128
+192 0 128
+64 128 128
+192 128 128
+0 64 0
+128 64 0
+0 192 0
+128 192 0
+0 64 128
+120 120 120
+180 120 120
+6 230 230
+80 50 50
+4 200 3
+120 120 80
+140 140 140
+204 5 255
+230 230 230
+4 250 7
+224 5 255
+235 255 7
+150 5 61
+120 120 70
+8 255 51
+255 6 82
+143 255 140
+204 255 4
+255 51 7
+204 70 3
+0 102 200
+61 230 250
+255 6 51
+11 102 255
+255 7 71
+128 0 0
+0 128 0
+128 128 0
+0 0 128
+128 0 128
+0 128 128
+128 128 128
+64 0 0
+192 0 0
+64 128 0
+192 128 0
+64 0 128
+192 0 128
+64 128 128
+192 128 128
+0 64 0
+128 64 0
+0 192 0
+128 192 0
+0 64 128
+255 9 224
+9 7 230
+220 220 220
+255 9 92
+112 9 255
+8 255 214
+7 255 224
+255 184 6
+10 255 71
+255 41 10
+7 255 255
+224 255 8
+102 8 255
+255 61 6
+255 194 7
+255 122 8
+0 255 20
+255 8 41
+255 5 153
+6 51 255
+235 12 255
+160 150 20
+0 163 255
+140 140 140
+250 10 15
+20 255 0
+31 255 0
+255 31 0
+255 224 0
+153 255 0
+0 0 255
+255 71 0
+0 235 255
+0 173 255
+31 0 255
+11 200 200
+255 82 0
+0 255 245
+0 61 255
+0 255 112
+0 255 133
+255 0 0
+255 163 0
+255 102 0
+194 255 0
+0 143 255
+51 255 0
+0 82 255
+0 255 41
+0 255 173
+10 0 255
+173 255 0
+0 255 153
+255 92 0
+255 0 255
+255 0 245
+255 0 102
+255 173 0
+255 0 20
+255 184 184
+0 31 255
+0 255 61
+0 71 255
+255 0 204
+0 255 194
+0 255 82
+0 10 255
+0 112 255
+51 0 255
+0 194 255
+0 122 255
+0 255 163
+255 153 0
+0 255 10
+255 112 0
+143 255 0
+82 0 255
+163 255 0
+255 235 0
+8 184 170
+133 0 255
+0 255 92
+184 0 255
+255 0 31
+0 184 255
+0 214 255
+255 0 112
+92 255 0
+0 224 255
+112 224 255
+70 184 160
+163 0 255
+153 0 255
+71 255 0
+255 0 163
+255 204 0
+255 0 143
+0 255 235
+133 255 0
+255 0 235
+245 0 255
+255 0 122
+255 245 0
+10 190 212
+214 255 0
+0 204 255
+20 0 255
+255 255 0
+0 153 255
+0 41 255
+0 255 204
+41 0 255
+41 255 0
+173 0 255
+0 245 255
+71 0 255
+122 0 255
+0 255 184
+0 92 255
+184 255 0
+0 133 255
+255 214 0
+25 194 194
+102 255 0
+92 0 255

+ 378 - 0
segmentation/evaluation/group_vit_seg.py

@@ -0,0 +1,378 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+
+import os.path as osp
+
+import matplotlib.pyplot as plt
+import mmcv
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from mmseg.models import EncoderDecoder
+from PIL import Image
+from utils import get_logger
+import cv2
+
+GROUP_PALETTE = np.loadtxt(osp.join(osp.dirname(osp.abspath(__file__)), 'group_palette.txt'), dtype=np.uint8)[:, ::-1]
+
+from ipdb import set_trace
+
+def resize_attn_map(attentions, h, w, align_corners=False):
+    """
+
+    Args:
+        attentions: shape [B, num_head, H*W, groups]
+        h:
+        w:
+
+    Returns:
+
+        attentions: shape [B, num_head, h, w, groups]
+
+
+    """
+    scale = (h * w // attentions.shape[2])**0.5
+    if h > w:
+        w_featmap = w // int(np.round(scale))
+        h_featmap = attentions.shape[2] // w_featmap
+    else:
+        h_featmap = h // int(np.round(scale))
+        w_featmap = attentions.shape[2] // h_featmap
+    assert attentions.shape[
+        2] == h_featmap * w_featmap, f'{attentions.shape[2]} = {h_featmap} x {w_featmap}, h={h}, w={w}'
+
+    bs = attentions.shape[0]
+    nh = attentions.shape[1]  # number of head
+    groups = attentions.shape[3]  # number of group token
+    # [bs, nh, h*w, groups] -> [bs*nh, groups, h, w]
+    attentions = rearrange(
+        attentions, 'bs nh (h w) c -> (bs nh) c h w', bs=bs, nh=nh, h=h_featmap, w=w_featmap, c=groups)
+    attentions = F.interpolate(attentions, size=(h, w), mode='bilinear', align_corners=align_corners)
+    #  [bs*nh, groups, h, w] -> [bs, nh, h*w, groups]
+    attentions = rearrange(attentions, '(bs nh) c h w -> bs nh h w c', bs=bs, nh=nh, h=h, w=w, c=groups)
+
+    return attentions
+
+
+def top_groups(attn_map, k):
+    """
+    Args:
+        attn_map: (B, H, W, G)
+        k: int
+
+    Return:
+        (B, H, W, k)
+    """
+
+    attn_map = attn_map.clone()
+
+    for i in range(attn_map.size(0)):
+        # [H*W, G]
+        flatten_map = rearrange(attn_map[i], 'h w g -> (h w) g')
+        kept_mat = torch.zeros(flatten_map.shape[0], device=flatten_map.device, dtype=torch.bool)
+        area_per_group = flatten_map.sum(dim=0)
+        top_group_idx = area_per_group.topk(k=k).indices.cpu().numpy().tolist()
+        for group_idx in top_group_idx:
+            kept_mat[flatten_map.argmax(dim=-1) == group_idx] = True
+        # [H, W, 2]
+        coords = torch.stack(
+            torch.meshgrid(
+                torch.arange(attn_map[i].shape[0], device=attn_map[i].device, dtype=attn_map[i].dtype),
+                torch.arange(attn_map[i].shape[1], device=attn_map[i].device, dtype=attn_map[i].dtype)),
+            dim=-1)
+        coords = rearrange(coords, 'h w c -> (h w) c')
+
+        # calculate distance between each pair of points
+        # [non_kept, kept]
+        dist_mat = torch.sum((coords[~kept_mat].unsqueeze(1) - coords[kept_mat].unsqueeze(0))**2, dim=-1)
+
+        flatten_map[~kept_mat] = flatten_map[kept_mat.nonzero(as_tuple=True)[0][dist_mat.argmin(dim=-1)]]
+
+        attn_map[i] = flatten_map.reshape_as(attn_map[i])
+
+    return attn_map
+
+
+def seg2coord(seg_map):
+    """
+    Args:
+        seg_map (np.ndarray): (H, W)
+
+    Return:
+        dict(group_id -> (x, y))
+    """
+    h, w = seg_map.shape
+    # [h ,w, 2]
+    coords = np.stack(np.meshgrid(np.arange(h), np.arange(w), indexing='ij'), axis=-1)
+    labels = np.unique(seg_map)
+    coord_map = {}
+    for label in labels:
+        coord_map[label] = coords[seg_map == label].mean(axis=0)
+    return coord_map
+
+
+class GroupViTSegInference(EncoderDecoder):
+
+    # def __init__(self, model, text_embedding, with_bg, test_cfg=dict(mode='whole', bg_thresh=.95, use_clip=False)):
+    def __init__(self, model, text_embedding, with_bg, test_cfg=dict(mode='whole', bg_thresh=.95)):
+        super(EncoderDecoder, self).__init__()
+        if not isinstance(test_cfg, mmcv.Config):
+            test_cfg = mmcv.Config(test_cfg)
+        self.test_cfg = test_cfg
+        self.model = model
+        # [N, C]
+        self.register_buffer('text_embedding', text_embedding)
+        self.with_bg = with_bg
+        self.bg_thresh = test_cfg['bg_thresh']
+        if self.with_bg:
+            self.num_classes = len(text_embedding) + 1
+        else:
+            self.num_classes = len(text_embedding)
+        self.align_corners = False
+        logger = get_logger()
+        logger.info(
+            f'Building GroupViTSegInference with {self.num_classes} classes, test_cfg={test_cfg}, with_bg={with_bg}')
+
+    def forward_train(self, img, img_metas, gt_semantic_seg):
+        raise NotImplementedError
+
+    def get_attn_maps(self, img, return_onehot=False, rescale=False):
+        """
+        Args:
+            img: [B, C, H, W]
+
+        Returns:
+            attn_maps: list[Tensor], attention map of shape [B, H, W, groups]
+        """
+        results = self.model.img_encoder(img, return_attn=True, as_dict=True)
+        
+        attn_maps = []
+        with torch.no_grad():
+            prev_attn_masks = None
+            for idx, attn_dict in enumerate(results['attn_dicts']):
+                if attn_dict is None:
+                    # changed: a missing attention dict is no longer restricted to the last layer, so just skip it
+                    # assert idx == len(results['attn_dicts']) - 1, 'only last layer can be None'
+                    continue
+                
+                # [B, nH, G, HxW]
+                # B: batch size (1), nH: number of heads, G: number of group tokens
+                attn_masks = attn_dict['soft']
+                # [B, nH, G, HxW] -> [B, nH, HxW, G]
+                attn_masks = rearrange(attn_masks, 'b h g n -> b h n g')
+                if prev_attn_masks is None:
+                    prev_attn_masks = attn_masks
+                else:
+                    prev_attn_masks = prev_attn_masks @ attn_masks
+
+                # [B, nH, HxW, G] -> [B, nH, H, W, G]
+                attn_maps.append(resize_attn_map(prev_attn_masks, *img.shape[-2:]))
+
+        for i in range(len(attn_maps)):
+            attn_map = attn_maps[i]
+            # [B, nh, H, W, G]
+            assert attn_map.shape[1] == 1
+            # [B, H, W, G]
+            attn_map = attn_map.squeeze(1)
+
+            if rescale:
+                attn_map = rearrange(attn_map, 'b h w g -> b g h w')
+                attn_map = F.interpolate(
+                    attn_map, size=img.shape[2:], mode='bilinear', align_corners=self.align_corners)
+                attn_map = rearrange(attn_map, 'b g h w -> b h w g')
+
+            if return_onehot:
+                # [B, H, W, G]
+                attn_map = F.one_hot(attn_map.argmax(dim=-1), num_classes=attn_map.shape[-1]).to(dtype=attn_map.dtype)
+
+            attn_maps[i] = attn_map
+        
+        return attn_maps
+
+    def encode_decode(self, img, img_metas):
+        """Encode images with backbone and decode into a semantic segmentation
+        map of the same size as input."""
+
+        assert img.shape[0] == 1, 'batch size must be 1'
+
+        # [B, C, H, W], get the last one only
+        attn_map = self.get_attn_maps(img, rescale=True)[-1]
+        # [H, W, G], select batch idx 0
+        attn_map = attn_map[0]
+        img_outs = self.model.encode_image(img, return_feat=True, as_dict=True)
+        # [B, L, C] -> [L, C]
+        grouped_img_tokens = img_outs['image_feat'].squeeze(0)
+        img_avg_feat = img_outs['image_x']        
+        # [G, C]
+        grouped_img_tokens = F.normalize(grouped_img_tokens, dim=-1)
+        img_avg_feat = F.normalize(img_avg_feat, dim=-1)
+
+        # [H, W, G]
+        onehot_attn_map = F.one_hot(attn_map.argmax(dim=-1), num_classes=attn_map.shape[-1]).to(dtype=attn_map.dtype)
+        num_fg_classes = self.text_embedding.shape[0]
+        class_offset = 1 if self.with_bg else 0
+        text_tokens = self.text_embedding
+        num_classes = num_fg_classes + class_offset
+
+        logit_scale = torch.clamp(self.model.logit_scale.exp(), max=100)
+        # [G, N]
+        group_affinity_mat = (grouped_img_tokens @ text_tokens.T) * logit_scale
+        pre_group_affinity_mat = F.softmax(group_affinity_mat, dim=-1)
+        avg_affinity_mat = (img_avg_feat @ text_tokens.T) * logit_scale
+        avg_affinity_mat = F.softmax(avg_affinity_mat, dim=-1)
+        
+        affinity_mask = torch.zeros_like(avg_affinity_mat)
+        avg_affinity_topk = avg_affinity_mat.topk(dim=-1, k=min(5, num_fg_classes))
+        affinity_mask.scatter_add_(
+            dim=-1, index=avg_affinity_topk.indices, src=torch.ones_like(avg_affinity_topk.values))
+        group_affinity_mat.masked_fill_(~affinity_mask.bool(), float('-inf'))
+        
+        group_affinity_mat = F.softmax(group_affinity_mat, dim=-1)
+
+        # TODO: check if necessary
+        group_affinity_mat *= pre_group_affinity_mat
+        pred_logits = torch.zeros(num_classes, *attn_map.shape[:2], device=img.device, dtype=img.dtype)
+
+        pred_logits[class_offset:] = rearrange(onehot_attn_map @ group_affinity_mat, 'h w c -> c h w')
+        
+        if self.with_bg:
+            bg_thresh = min(self.bg_thresh, group_affinity_mat.max().item())
+            pred_logits[0, (onehot_attn_map @ group_affinity_mat).max(dim=-1).values < bg_thresh] = 1
+
+        return pred_logits.unsqueeze(0)
+
+    def blend_result(self, img, result, palette=None, out_file=None, opacity=0.5, with_bg=False):
+        img = mmcv.imread(img)
+        img = img.copy()
+        seg = result[0]
+        if palette is None:
+            palette = self.PALETTE
+        palette = np.array(palette)
+        assert palette.shape[1] == 3, palette.shape
+        assert len(palette.shape) == 2
+        assert 0 < opacity <= 1.0
+        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+        for label, color in enumerate(palette):
+            color_seg[seg == label, :] = color
+        # convert to BGR
+        color_seg = color_seg[..., ::-1]
+
+        if with_bg:
+            fg_mask = seg != 0
+            img[fg_mask] = img[fg_mask] * (1 - opacity) + color_seg[fg_mask] * opacity
+        else:
+            img = img * (1 - opacity) + color_seg * opacity
+        img = img.astype(np.uint8)
+
+        if out_file is not None:
+            mmcv.imwrite(img, out_file)
+
+        return img
+
+    def show_result(self, img_show, img_tensor, result, out_file, vis_mode='pred'):
+        print('current vis mode: ', vis_mode)
+        assert vis_mode in [
+            'input', 'pred', 'input_pred', 'all_groups', 'first_group', 'final_group', 'input_pred_label'
+        ], vis_mode
+
+        if vis_mode == 'input':
+            mmcv.imwrite(img_show, out_file)
+        elif vis_mode == 'pred':
+            output = Image.fromarray(result[0].astype(np.uint8)).convert('P')
+            output.putpalette(np.array(self.PALETTE).astype(np.uint8))
+            mmcv.mkdir_or_exist(osp.dirname(out_file))
+            output.save(out_file.replace('.jpg', '.png'))
+        elif vis_mode == 'input_pred':
+            self.blend_result(img=img_show, result=result, out_file=out_file, opacity=0.5, with_bg=self.with_bg)
+        elif vis_mode == 'input_pred_label':
+            labels = np.unique(result[0])
+            coord_map = seg2coord(result[0])
+            # reference: https://github.com/open-mmlab/mmdetection/blob/ff9bc39913cb3ff5dde79d3933add7dc2561bab7/mmdet/models/detectors/base.py#L271 # noqa
+            blended_img = self.blend_result(
+                img=img_show, result=result, out_file=None, opacity=0.5, with_bg=self.with_bg)
+            blended_img = mmcv.bgr2rgb(blended_img)
+            width, height = img_show.shape[1], img_show.shape[0]
+            EPS = 1e-2
+            fig = plt.figure(frameon=False)
+            canvas = fig.canvas
+            dpi = fig.get_dpi()
+            fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)
+
+            # remove white edges by set subplot margin
+            plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
+            ax = plt.gca()
+            ax.axis('off')
+            for i, label in enumerate(labels):
+                if self.with_bg and label == 0:
+                    continue
+                center = coord_map[label].astype(np.int32)
+                label_text = self.CLASSES[label]
+                ax.text(
+                    center[1],
+                    center[0],
+                    f'{label_text}',
+                    bbox={
+                        'facecolor': 'black',
+                        'alpha': 0.5,
+                        'pad': 0.7,
+                        'edgecolor': 'none'
+                    },
+                    color='orangered',
+                    fontsize=16,
+                    verticalalignment='top',
+                    horizontalalignment='left')
+            plt.imshow(blended_img)
+            stream, _ = canvas.print_to_buffer()
+            buffer = np.frombuffer(stream, dtype='uint8')
+            img_rgba = buffer.reshape(height, width, 4)
+            rgb, alpha = np.split(img_rgba, [3], axis=2)
+            img = rgb.astype('uint8')
+            img = mmcv.rgb2bgr(img)
+            mmcv.imwrite(img, out_file)
+            plt.close()
+
+        elif vis_mode == 'all_groups' or vis_mode == 'final_group' or vis_mode == 'first_group':
+            attn_map_list = self.get_attn_maps(img_tensor)
+            assert len(attn_map_list) in [1, 2]
+            # only show 16 groups for the first stage
+            # if len(attn_map_list) == 2:
+            #     attn_map_list[0] = top_groups(attn_map_list[0], k=16)
+
+            num_groups = [attn_map_list[layer_idx].shape[-1] for layer_idx in range(len(attn_map_list))]
+            for layer_idx, attn_map in enumerate(attn_map_list):
+                if vis_mode == 'first_group' and layer_idx != 0:
+                    continue
+                if vis_mode == 'final_group' and layer_idx != len(attn_map_list) - 1:
+                    continue
+                attn_map = rearrange(attn_map, 'b h w g -> b g h w')
+                attn_map = F.interpolate(
+                    attn_map, size=img_show.shape[:2], mode='bilinear', align_corners=self.align_corners)
+                group_result = attn_map.argmax(dim=1).cpu().numpy()
+                if vis_mode == 'all_groups':
+                    layer_out_file = out_file.replace(
+                        osp.splitext(out_file)[-1], f'_layer{layer_idx}{osp.splitext(out_file)[-1]}')
+                else:
+                    layer_out_file = out_file
+                self.blend_result(
+                    img=img_show,
+                    result=group_result,
+                    palette=GROUP_PALETTE[sum(num_groups[:layer_idx]):sum(num_groups[:layer_idx + 1])],
+                    out_file=layer_out_file,
+                    opacity=0.5)
+        else:
+            raise ValueError(f'Unknown vis_type: {vis_mode}')
+        
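
The core of `encode_decode` above is a soft assignment of group tokens to class prompts, gated by a background threshold. A self-contained sketch with random tensors (shapes, the scale of 100 and the 0.95 threshold are illustrative only):

import torch
import torch.nn.functional as F

G, N, H, W, C = 8, 20, 56, 56, 256                                  # groups, classes, spatial size, feat dim
group_tokens = F.normalize(torch.randn(G, C), dim=-1)               # grouped image tokens
text_tokens = F.normalize(torch.randn(N, C), dim=-1)                # class/prompt embeddings
onehot_attn = F.one_hot(torch.randint(0, G, (H, W)), G).float()     # hard group map, [H, W, G]

affinity = F.softmax(100.0 * group_tokens @ text_tokens.T, dim=-1)  # [G, N], logit scale clamped at 100
pixel_logits = torch.einsum('hwg,gn->nhw', onehot_attn, affinity)   # [N, H, W] per-pixel class scores
is_background = pixel_logits.max(dim=0).values < 0.95               # pixels below bg_thresh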

+ 32 - 0
setup.cfg

@@ -0,0 +1,32 @@
+[yapf]
+based_on_style = pep8
+blank_line_before_nested_class_or_def = true
+split_before_expression_after_opening_paren = true
+column_limit = 120
+
+[isort]
+line_length = 120
+multi_line_output = 0
+known_standard_library = setuptools
+known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,packaging,prettytable,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,torch,ts
+no_lines_before = STDLIB,LOCALFOLDER
+default_section = THIRDPARTY
+
+[codespell]
+skip = *.po,*.ts,*.ipynb
+count =
+quiet-level = 3
+ignore-words-list = formating,sur,hist
+
+[flake8]
+ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811
+max-line-length = 120
+max-complexity = 18
+select = B,C,E,F,W,T4,B9
+exclude = build
+per-file-ignores =
+  **/__init__.py:F401,F403,E402
+  **/configs/**.py:F401,E402
+  configs/**.py:F401,E402
+  **/tests/config/**.py:F401,E402
+  tests/config/**.py:F401,E402

+ 5 - 0
tools/debug.sh

@@ -0,0 +1,5 @@
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p video3 -N1 -n2 --gres=gpu:2 --job-name=4m --quotatype=auto --cpus-per-task=12 \
+python -u -m main_pretrain \
+    --cfg configs/ovsegmentor/ovsegmentor_pretrain_vit_bert_stage2.yml \
+    --amp-opt-level O0 

+ 3 - 0
tools/run.sh

@@ -0,0 +1,3 @@
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=29500 \
+    main_pretrain.py --cfg configs/ovsegmentor/ovsegmentor_pretrain_vit_bert_stage1.yml

+ 4 - 0
tools/run_slurm.sh

@@ -0,0 +1,4 @@
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p video -N1 -n8 --gres=gpu:8 --job-name=4m --quotatype=auto --cpus-per-task=12 \
+python -u -m main_pretrain \
+    --cfg configs/ovsegmentor/ovsegmentor_pretrain_vit_bert_stage1.yml

+ 4 - 0
tools/run_slurm_stage2.sh

@@ -0,0 +1,4 @@
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p video -N1 -n8 --gres=gpu:8 --job-name=4m_stage2 --quotatype=auto --cpus-per-task=12 \
+python -u -m main_pretrain \
+    --cfg configs/ovsegmentor/ovsegmentor_pretrain_vit_bert_stage2.yml

+ 3 - 0
tools/run_stage2.sh

@@ -0,0 +1,3 @@
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=29500 \
+    main_pretrain.py --cfg configs/ovsegmentor/ovsegmentor_pretrain_vit_bert_stage2.yml

+ 5 - 0
tools/test_ade20k.sh

@@ -0,0 +1,5 @@
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p video -N1 -n1 --gres=gpu:1 --job-name=test_ade20k --quotatype=auto --cpus-per-task=10 \
+python -u -m main_seg \
+    --cfg configs/test_ade20k.yml \
+    --resume /mnt/petrelfs/xujilan/exps/cc12m_100/test_voc12_bs256x1/best_miou.pth \

+ 5 - 0
tools/test_coco.sh

@@ -0,0 +1,5 @@
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p video -N1 -n1 --gres=gpu:1 --job-name=test_coco --quotatype=auto --cpus-per-task=10 \
+python -u -m main_seg \
+    --cfg configs/test_coco.yml \
+    --resume /mnt/petrelfs/xujilan/exps/cc12m_100/test_voc12_bs256x1/best_miou.pth \

+ 5 - 0
tools/test_context.sh

@@ -0,0 +1,5 @@
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p video -n1 -N1 --gres=gpu:1 --job-name=test_voc_context --quotatype=auto --cpus-per-task=10 \
+python -u -m main_seg \
+    --cfg configs/test_voc_context.yml \
+    --resume /mnt/petrelfs/xujilan/exps/cc12m_100/test_voc12_bs256x1/best_miou.pth \

+ 5 - 0
tools/test_voc12.sh

@@ -0,0 +1,5 @@
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p video -N1 -n1 --gres=gpu:1 --job-name=test_voc12 --quotatype=auto --cpus-per-task=10 -x SH-IDC1-10-140-24-19 \
+python -u -m main_seg \
+    --cfg configs/test_voc12.yml \
+    --resume /mnt/petrelfs/xujilan/exps/cc12m_100/test_voc12_bs256x1/best_miou.pth \

+ 28 - 0
utils/__init__.py

@@ -0,0 +1,28 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+
+from .checkpoint import auto_resume_helper, load_checkpoint, save_checkpoint, load_checkpoint_stage1
+from .config import get_config
+from .logger import get_logger
+from .lr_scheduler import build_scheduler
+from .misc import (build_dataset_class_lists, build_dataset_class_tokens, cdist_, data2cuda, get_batch_size,
+                   get_grad_norm, momentum_update, parse_losses, reduce_tensor)
+from .optimizer import build_optimizer
+
+__all__ = [
+    'get_config', 'get_logger', 'build_optimizer', 'build_scheduler', 'load_checkpoint', 'load_checkpoint_stage1',
+    'save_checkpoint', 'auto_resume_helper', 'reduce_tensor', 'get_grad_norm', 'get_batch_size', 'data2cuda',
+    'parse_losses', 'momentum_update', 'build_dataset_class_tokens', 'build_dataset_class_lists', 'cdist_',
+]

+ 186 - 0
utils/checkpoint.py

@@ -0,0 +1,186 @@
+# -------------------------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE
+#
+# Written by Ze Liu, Zhenda Xie
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+
+import os
+from collections import defaultdict
+
+import torch
+import torch.distributed as dist
+from mmcv.runner import CheckpointLoader
+from omegaconf import read_write
+
+from .logger import get_logger
+from ipdb import set_trace
+
+try:
+    # noinspection PyUnresolvedReferences
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+def load_checkpoint_stage1(config, model):
+    logger = get_logger()
+    logger.info(f'==============> Resuming stage1 checkpoint from {config.checkpoint.stage1_checkpoint}....................')
+    checkpoint = CheckpointLoader.load_checkpoint(config.checkpoint.stage1_checkpoint, map_location='cpu')
+    ### load online model parameters ###
+    # msg = model.load_state_dict(checkpoint['model'], strict=False)
+    new_state_dict = {}
+    new_params = ['logit_scale_mask']
+    for k, v in model.state_dict().items():
+        if k in new_params:
+            continue
+        if k in checkpoint['model']:
+            new_state_dict[k] = checkpoint['model'][k]
+        else:
+            oldk = k.replace('img_encoder_momentum', 'img_encoder')
+            # new_state_dict[k] = checkpoint['model'][oldk]
+            if oldk in checkpoint['model']:
+               new_state_dict[k] = checkpoint['model'][oldk]
+    
+    msg = model.load_state_dict(new_state_dict, strict=False)
+    logger.info(msg)
+    
+    del checkpoint
+    torch.cuda.empty_cache()
+
+def load_checkpoint(config, model, optimizer, lr_scheduler):
+    logger = get_logger()
+    logger.info(f'==============> Resuming from {config.checkpoint.resume}....................')
+    checkpoint = CheckpointLoader.load_checkpoint(config.checkpoint.resume, map_location='cpu')
+    msg = model.load_state_dict(checkpoint['model'], strict=False)
+    logger.info(msg)
+    metrics = defaultdict(float)
+    if (not config.evaluate.eval_only and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint
+            and 'epoch' in checkpoint):
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        with read_write(config):
+            config.train.start_epoch = checkpoint['epoch'] + 1
+        if 'amp' in checkpoint and config.train.amp_opt_level != 'O0' and checkpoint[
+                'config'].train.amp_opt_level != 'O0':
+            amp.load_state_dict(checkpoint['amp'])
+        logger.info(f"=> loaded successfully '{config.checkpoint.resume}' (epoch {checkpoint['epoch']})")
+        metrics = checkpoint['metrics']
+
+    del checkpoint
+    torch.cuda.empty_cache()
+    return metrics
+
+
+def save_checkpoint(config, epoch, model, metrics, optimizer, lr_scheduler, suffix=''):
+    save_state = {
+        'model': model.state_dict(),
+        'optimizer': optimizer.state_dict(),
+        'lr_scheduler': lr_scheduler.state_dict(),
+        'metrics': metrics,
+        'epoch': epoch,
+        'config': config
+    }
+    logger = get_logger()
+
+    for k, v in metrics.items():
+        save_state[k] = v
+
+    if config.train.amp_opt_level != 'O0':
+        save_state['amp'] = amp.state_dict()
+
+    if len(suffix) > 0 and not suffix.startswith('_') and suffix != 'best_miou':
+        suffix = '_' + suffix
+    
+    if epoch >= 10 and epoch % 10 == 0 and suffix != 'best_miou':
+        filename = f'ckpt_epoch_{epoch}{suffix}.pth'
+        save_path = os.path.join(config.output, filename)
+        torch.save(save_state, save_path)
+
+    ##### this is for per epoch saving, easy for resuming #####    
+    # filename = f'ckpt_epoch_{suffix}.pth' # only save the best one
+    # save_path = os.path.join(config.output, filename)
+    # logger.info(f'{save_path} saving......')
+    if suffix == 'best_miou':
+        print('saving best iou checkpoint')
+        filename = 'best_miou.pth' # only save the best one
+        current_save_path = os.path.join(config.output, filename)
+        torch.save(save_state, current_save_path)    
+        logger.info(f'{current_save_path} saved for best iou!!!')
+    else:
+        current_save_path = os.path.join(config.output, 'checkpoint.pth')
+        torch.save(save_state, current_save_path)
+        logger.info(f'{current_save_path} saved !!!')
+
+    # if config.checkpoint.max_kept > 0:
+    #     if epoch >= config.checkpoint.max_kept:
+    #         logger.info(f'Epoch: {epoch}, greater than config.checkpoint.max_kept: {config.checkpoint.max_kept}')
+    #         end_clean_epoch = epoch - config.checkpoint.max_kept
+    #         old_path_list = []
+    #         for cur_clean_epoch in range(end_clean_epoch + 1):
+    #             old_path = os.path.join(config.output, f'ckpt_epoch_{cur_clean_epoch}{suffix}.pth')
+    #             if os.path.exists(old_path):
+    #                 logger.info(f'old checkpoint path {old_path} exits')
+    #                 old_path_list.append(old_path)
+    #         for old_path in old_path_list[:-config.checkpoint.max_kept]:
+    #             os.remove(old_path)
+    #             logger.info(f'old checkpoint path {old_path} removed!!!')
+
+def get_grad_norm(parameters, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item()**norm_type
+    total_norm = total_norm**(1. / norm_type)
+    return total_norm
+
+
+def auto_resume_helper(output_dir):
+    if os.path.exists(os.path.join(output_dir, 'checkpoint.pth')):
+        return os.path.join(output_dir, 'checkpoint.pth')
+
+    checkpoints = os.listdir(output_dir)
+    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
+    print(f'All checkpoints found in {output_dir}: {checkpoints}')
+    if len(checkpoints) > 0:
+        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
+        print(f'The latest checkpoint found: {latest_checkpoint}')
+        resume_file = latest_checkpoint
+    else:
+        resume_file = None
+    return resume_file
+
+
+def reduce_tensor(tensor):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+    rt /= dist.get_world_size()
+    return rt
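
A hedged sketch of the auto-resume flow these helpers support (`config`, `model`, `optimizer` and `lr_scheduler` are assumed to live in the training entry point; they are not defined in this file):

from omegaconf import read_write

resume_file = auto_resume_helper(config.output)
if resume_file is not None:
    with read_write(config):                      # config is a read-only OmegaConf tree
        config.checkpoint.resume = resume_file
    metrics = load_checkpoint(config, model, optimizer, lr_scheduler)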

+ 79 - 0
utils/config.py

@@ -0,0 +1,79 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import os
+import os.path as osp
+
+from omegaconf import OmegaConf
+
+
+def load_config(cfg_file):
+    cfg = OmegaConf.load(cfg_file)
+    if '_base_' in cfg:
+        if isinstance(cfg._base_, str):
+            base_cfg = OmegaConf.load(osp.join(osp.dirname(cfg_file), cfg._base_))
+        else:
+            base_cfg = OmegaConf.merge(OmegaConf.load(f) for f in cfg._base_)
+        cfg = OmegaConf.merge(base_cfg, cfg)
+    return cfg
+
+
+def get_config(args):
+    cfg = load_config(args.cfg)
+    OmegaConf.set_struct(cfg, True)
+
+    if args.opts is not None:
+        cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(args.opts))
+    if hasattr(args, 'batch_size') and args.batch_size:
+        cfg.data.batch_size = args.batch_size
+
+    if hasattr(args, 'amp_opt_level') and args.amp_opt_level:
+        cfg.train.amp_opt_level = args.amp_opt_level
+
+    if hasattr(args, 'resume') and args.resume:
+        cfg.checkpoint.resume = args.resume
+
+    if hasattr(args, 'eval') and args.eval:
+        cfg.evaluate.eval_only = args.eval
+
+    if hasattr(args, 'keep') and args.keep:
+        cfg.checkpoint.max_kept = args.keep
+
+    if not cfg.model_name:
+        cfg.model_name = osp.splitext(osp.basename(args.cfg))[0]
+
+    world_size = int(os.environ.get('WORLD_SIZE', 1))
+    cfg.model_name = cfg.model_name + f'_bs{cfg.data.batch_size}x{world_size}'
+
+    # if hasattr(args, 'output') and args.output:
+        # cfg.output = args.output
+    if hasattr(args, 'output'):
+        cfg.output = osp.join(cfg.output, cfg.model_name)
+    else:
+        cfg.output = osp.join('output', cfg.model_name)
+
+    if hasattr(args, 'tag') and args.tag:
+        cfg.tag = args.tag
+        cfg.output = osp.join(cfg.output, cfg.tag)
+
+    if hasattr(args, 'wandb') and args.wandb:
+        cfg.wandb = args.wandb
+
+    if hasattr(args, 'vis') and args.vis:
+        cfg.vis = args.vis
+
+    cfg.local_rank = args.local_rank
+
+    OmegaConf.set_readonly(cfg, True)
+
+    return cfg
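
The `--opts` overrides consumed by `get_config` are plain OmegaConf dot-list entries; a quick illustration of the merge it performs (keys and values below are examples only):

from omegaconf import OmegaConf

cfg = OmegaConf.create({'data': {'batch_size': 256}, 'train': {'base_lr': 0.0004}})
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(['data.batch_size=128', 'train.base_lr=0.0001']))
assert cfg.data.batch_size == 128 and cfg.train.base_lr == 0.0001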

+ 61 - 0
utils/logger.py

@@ -0,0 +1,61 @@
+# -------------------------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE
+#
+# Written by Ze Liu, Zhenda Xie
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import logging
+import os.path as osp
+
+from mmcv.utils import get_logger as get_root_logger
+from termcolor import colored
+
+logger_name = None
+
+
+def get_logger(cfg=None, log_level=logging.INFO):
+    global logger_name
+    if cfg is None:
+        return get_root_logger(logger_name)
+
+    # creating logger
+    name = cfg.model_name
+    output = cfg.output
+    logger_name = name
+
+    logger = get_root_logger(name, osp.join(output, 'log.txt'), log_level=log_level, file_mode='a')
+
+    fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
+    color_fmt = colored('[%(asctime)s %(name)s]', 'green') \
+        + colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
+
+    for handler in logger.handlers:
+        if isinstance(handler, logging.StreamHandler):
+            handler.setFormatter(logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+
+        if isinstance(handler, logging.FileHandler):
+            handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+
+    return logger

+ 36 - 0
utils/lr_scheduler.py

@@ -0,0 +1,36 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from timm.scheduler.cosine_lr import CosineLRScheduler
+
+
+def build_scheduler(config, optimizer, n_iter_per_epoch):
+    num_steps = int(config.epochs * n_iter_per_epoch)
+    warmup_steps = int(config.warmup_epochs * n_iter_per_epoch)
+
+    lr_scheduler = None
+    if config.lr_scheduler.name == 'cosine':
+        lr_scheduler = CosineLRScheduler(
+            optimizer,
+            t_initial=num_steps,
+            # t_mul=1.,  ## this does not work with higher versions of timm
+            lr_min=config.min_lr,
+            warmup_lr_init=config.warmup_lr,
+            warmup_t=warmup_steps,
+            cycle_limit=1,
+            t_in_epochs=False,
+        )
+    else:
+        raise NotImplementedError(f'lr scheduler {config.lr_scheduler.name} not implemented')
+
+    return lr_scheduler
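
Because `t_in_epochs=False`, the returned scheduler must be advanced with the global iteration index via `step_update`, not per epoch. A self-contained sketch (all hyper-parameters and the config node name are placeholders):

import torch
from omegaconf import OmegaConf

train_cfg = OmegaConf.create({
    'epochs': 2, 'warmup_epochs': 1, 'min_lr': 1e-6, 'warmup_lr': 1e-7,
    'lr_scheduler': {'name': 'cosine'},
})
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=4e-4)
n_iter_per_epoch = 100
scheduler = build_scheduler(train_cfg, optimizer, n_iter_per_epoch)

for epoch in range(train_cfg.epochs):
    for it in range(n_iter_per_epoch):
        # ... forward / backward / optimizer.step() ...
        scheduler.step_update(epoch * n_iter_per_epoch + it)   # global step, not epoch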

+ 126 - 0
utils/misc.py

@@ -0,0 +1,126 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+# Modified by Jilan Xu
+# -------------------------------------------------------------------------
+
+import collections.abc
+from collections import OrderedDict
+import json
+import cv2
+import numpy as np
+from PIL import Image
+from ipdb import set_trace
+import scipy
+
+import torch
+import torch.distributed as dist
+from datasets import template_meta
+
+def reduce_tensor(tensor):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+    rt /= dist.get_world_size()
+    return rt
+
+
+def momentum_update(model_on, model_off, coeff):
+    for params_on, params_off in zip(model_on.parameters(), model_off.parameters()):
+        params_off.data = coeff * params_off.data + (1 - coeff) * params_on.data
+
+def get_grad_norm(parameters, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item()**norm_type
+    total_norm = total_norm**(1. / norm_type)
+    return total_norm
+
+
+def get_batch_size(data):
+
+    if isinstance(data, torch.Tensor):
+        return data.size(0)
+    elif isinstance(data, collections.abc.Mapping):
+        return get_batch_size(data[next(iter(data))])
+    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
+        # check to make sure that the elements in batch have consistent size
+        it = iter(data)
+        return get_batch_size(next(it))
+
+    raise TypeError
+
+
+def data2cuda(data):
+
+    if isinstance(data, torch.Tensor):
+        batch = data.cuda(non_blocking=True)
+        return batch
+    elif isinstance(data, collections.abc.Mapping):
+        return {key: data2cuda(data[key]) for key in data}
+    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
+        return [data2cuda(d) for d in data]
+    else:
+        raise TypeError
+
+def parse_losses(losses):
+    log_vars = OrderedDict()
+    for loss_name, loss_value in losses.items():
+        if isinstance(loss_value, torch.Tensor):
+            log_vars[loss_name] = loss_value.mean()
+        elif isinstance(loss_value, list):
+            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+        else:
+            raise TypeError(f'{loss_name} is not a tensor or list of tensors')
+
+    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
+    return loss, log_vars
+
+
+def build_dataset_class_tokens(text_transform, template_set, classnames):
+
+    tokens = []
+    templates = template_meta[template_set]
+    for classname in classnames:
+        # format with class
+        tokens.append(torch.stack([text_transform(template.format(classname)) for template in templates]))
+    # [N, T, L], N: number of instance, T: number of captions (including ensembled), L: sequence length
+    tokens = torch.stack(tokens)
+
+    return tokens
+
+def build_dataset_class_lists(template_set, classnames):
+    tokens = []
+    templates = template_meta[template_set]
+    for classname in classnames:
+        # format with class
+        for template in templates:
+            tokens.append(template.format(classname))
+    # returns a flat list of len(classnames) * len(templates) prompt strings;
+    # tokenization is left to the caller (e.g. a BERT tokenizer)
+    return tokens
+            
+def cdist_(x, metric='euclidean'):
+    assert len(x.shape) == 3
+    if metric != 'JS':
+        x_ = torch.split(x, 1, dim=0)  # tuple
+        return np.mean(tuple(map(lambda a: scipy.spatial.distance.cdist(a.squeeze(), a.squeeze(), metric).mean(), x_)))
+    else:
+        softmax = torch.nn.Softmax(dim=1)
+        x_ = torch.split(softmax(x), 1, dim=0)  # tuple
+        return np.mean(tuple(map(lambda a: scipy.spatial.distance.cdist(a.squeeze(), a.squeeze(), metric).mean(), x_)))
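
A quick illustration of `parse_losses`: only entries whose key contains 'loss' contribute to the returned total, everything else is logged only (the loss names and values below are made up):

import torch

losses = {
    'contrastive_loss': torch.tensor(1.2),
    'mask_loss': [torch.tensor(0.3), torch.tensor(0.1)],
    'acc1': torch.tensor(0.8),          # logged in log_vars, excluded from the total
}
total, log_vars = parse_losses(losses)
assert torch.isclose(total, torch.tensor(1.6))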

+ 72 - 0
utils/optimizer.py

@@ -0,0 +1,72 @@
+# -------------------------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE
+#
+# Written by Ze Liu, Zhenda Xie
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from torch import optim as optim
+
+
+def build_optimizer(config, model):
+    """Build optimizer, set weight decay of normalization to 0 by default."""
+    parameters = set_weight_decay(model, {}, {})
+
+    opt_name = config.optimizer.name
+    optimizer = None
+    if opt_name == 'adamw':
+        optimizer = optim.AdamW(
+            parameters,
+            eps=config.optimizer.eps,
+            betas=config.optimizer.betas,
+            lr=config.base_lr,
+            weight_decay=config.weight_decay)
+    else:
+        raise ValueError(f'Unsupported optimizer: {opt_name}')
+
+    return optimizer
+
+
+def set_weight_decay(model, skip_list=(), skip_keywords=()):
+    has_decay = []
+    no_decay = []
+
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue  # frozen weights
+        if len(param.shape) == 1 or name.endswith('.bias') or (name in skip_list) or \
+                check_keywords_in_name(name, skip_keywords):
+            no_decay.append(param)
+            # print(f"{name} has no weight decay")
+        else:
+            has_decay.append(param)
+    return [{'params': has_decay}, {'params': no_decay, 'weight_decay': 0.}]
+
+
+def check_keywords_in_name(name, keywords=()):
+    isin = False
+    for keyword in keywords:
+        if keyword in name:
+            isin = True
+    return isin
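
A minimal check of the weight-decay split performed by `set_weight_decay` (the toy model and hyper-parameters are illustrative): biases and 1-D parameters such as LayerNorm weights land in the second parameter group with `weight_decay=0`.

import torch
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'base_lr': 4e-4, 'weight_decay': 0.05,
    'optimizer': {'name': 'adamw', 'eps': 1e-8, 'betas': [0.9, 0.98]},
})
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
optimizer = build_optimizer(cfg, model)
assert optimizer.param_groups[1]['weight_decay'] == 0.0   # biases + LayerNorm params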