
Added a first version.

Shalini De Mello 3 years ago
parent
commit
d6ee6ce211
56 changed files with 5737 additions and 0 deletions
  1. 332 0
      README.md
  2. 109 0
      configs/default.yml
  3. 29 0
      configs/group_vit_gcc_redcap_30e.yml
  4. 26 0
      configs/group_vit_gcc_yfcc_30e.yml
  5. 282 0
      convert_dataset/convert_coco_object.py
  6. 119 0
      convert_dataset/convert_yfcc14m.py
  7. 114 0
      convert_dataset/create_subset.py
  8. 71 0
      convert_dataset/process_redcaps.py
  9. 20 0
      datasets/__init__.py
  10. binary
      datasets/bpe_simple_vocab_16e6.txt.gz
  11. 337 0
      datasets/builder.py
  12. 36 0
      datasets/formatting.py
  13. 267 0
      datasets/imagenet_template.py
  14. 160 0
      datasets/tokenizer.py
  15. 142 0
      demo/demo_seg.py
  16. binary
      demo/examples/coco.jpg
  17. binary
      demo/examples/ctx.jpg
  18. binary
      demo/examples/voc.jpg
  19. binary
      figs/github_arch.gif
  20. binary
      figs/github_coco.gif
  21. binary
      figs/github_ctx.gif
  22. binary
      figs/github_voc.gif
  23. 460 0
      main_group_vit.py
  24. 194 0
      main_seg.py
  25. 19 0
      models/__init__.py
  26. 24 0
      models/builder.py
  27. 882 0
      models/group_vit.py
  28. 74 0
      models/misc.py
  29. 302 0
      models/multi_label_contrastive.py
  30. 117 0
      models/transformer.py
  31. 59 0
      models/utils.py
  32. binary
      segmentation/.DS_Store
  33. binary
      segmentation/configs/.DS_Store
  34. binary
      segmentation/configs/_base_/.DS_Store
  35. 15 0
      segmentation/configs/_base_/custom_import.py
  36. 44 0
      segmentation/configs/_base_/datasets/coco.py
  37. 43 0
      segmentation/configs/_base_/datasets/pascal_context.py
  38. 43 0
      segmentation/configs/_base_/datasets/pascal_voc12.py
  39. 18 0
      segmentation/datasets/__init__.py
  40. 48 0
      segmentation/datasets/coco_object.py
  41. 26 0
      segmentation/datasets/pascal_context.py
  42. 22 0
      segmentation/datasets/pascal_voc.py
  43. 20 0
      segmentation/evaluation/__init__.py
  44. 109 0
      segmentation/evaluation/builder.py
  45. 209 0
      segmentation/evaluation/group_palette.txt
  46. 370 0
      segmentation/evaluation/group_vit_seg.py
  47. 32 0
      setup.cfg
  48. 24 0
      tools/dist_launch.sh
  49. 29 0
      tools/dist_mn_launch.sh
  50. 25 0
      utils/__init__.py
  51. 145 0
      utils/checkpoint.py
  52. 77 0
      utils/config.py
  53. 61 0
      utils/logger.py
  54. 36 0
      utils/lr_scheduler.py
  55. 94 0
      utils/misc.py
  56. 72 0
      utils/optimizer.py

+ 332 - 0
README.md

@@ -0,0 +1,332 @@
+# GroupViT: Semantic Segmentation Emerges from Text Supervision
+
+This repository is the official implementation of GroupViT, introduced in the paper:
+
+[**GroupViT: Semantic Segmentation Emerges from Text Supervision**](https://arxiv.org/abs/2202.11094)
+<br>
+[*Jiarui Xu*](https://jerryxu.net),
+[*Shalini De Mello*](https://research.nvidia.com/person/shalini-gupta),
+[*Wonmin Byeon*](https://wonmin-byeon.github.io/),
+[*Thomas Breuel*](http://www.tmbdev.net/),
+[*Jan Kautz*](https://research.nvidia.com/person/jan-kautz),
+[*Xiaolong Wang*](https://xiaolonw.github.io/)
+<br>
+CVPR 2022
+
+The project page with examples is at [https://jerryxu.net/GroupViT/](https://jerryxu.net/GroupViT/).
+
+<div align="center">
+<img src="figs/github_arch.gif" width="85%">
+</div>
+
+## Citation
+
+If you find our work useful in your research, please cite:
+
+```latex
+@article{xu2022groupvit,
+  author    = {Xu, Jiarui and De Mello, Shalini and Liu, Sifei and Byeon, Wonmin and Breuel, Thomas and Kautz, Jan and Wang, Xiaolong},
+  title     = {GroupViT: Semantic Segmentation Emerges from Text Supervision},
+  journal   = {arXiv preprint arXiv:2202.11094},
+  year      = {2022},
+}
+```
+
+## Environment Setup
+
+* Python 3.7
+* PyTorch 1.8
+* webdataset 0.1.103
+* mmsegmentation 0.18.0
+* timm 0.4.12
+
+Quick-start installation script:
+
+```shell
+conda create -n groupvit python=3.7 -y
+conda activate groupvit
+conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge
+pip install mmcv-full==1.3.14 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html
+pip install mmsegmentation==0.18.0
+pip install webdataset==0.1.103
+pip install timm==0.4.12
+git clone https://github.com/NVIDIA/apex
+cd apex && pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
+pip install opencv-python==4.4.0.46 termcolor==1.1.0 diffdist einops omegaconf
+pip install nltk ftfy regex tqdm
+```
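+
+To sanity-check the installation (a minimal sketch; the exact version strings depend on your environment), verify that the key packages import:
+
+```python
+# Minimal sketch: confirm the core dependencies are importable and report their versions.
+import torch
+import mmcv
+import mmseg
+import timm
+import webdataset
+
+print('torch', torch.__version__, '| CUDA available:', torch.cuda.is_available())
+print('mmcv', mmcv.__version__)
+print('mmsegmentation', mmseg.__version__)
+print('timm', timm.__version__)
+print('webdataset loaded from', webdataset.__file__)
+```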
+
+## Demo
+
+Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the web demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/xvjiarui/GroupViT)
+
+Run demo on Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pdJVfAZUchMiHCraA_qBwAs4xnt1ekIU)
+
+To run the demo from the command line:
+
+```shell
+python demo/demo_seg.py --cfg configs/group_vit_gcc_yfcc_30e.yml --resume /path/to/checkpoint --vis input_pred_label final_group --input demo/examples/voc.jpg --output_dir demo/output
+```
+The output is saved in `demo/output/`.
+
+## Benchmark
+
+<table>
+<thead>
+  <tr>
+    <th></th>
+    <th>Zero-shot Classification</th>
+    <th colspan="3">Zero-shot Segmentation</th>
+  </tr>
+</thead>
+<tbody>
+  <tr>
+    <td>config</td>
+    <td>ImageNet</td>
+    <td>Pascal VOC</td>
+    <td>Pascal Context</td>
+    <td>COCO</td>
+  </tr>
+  <tr>
+    <td><a href="configs/group_vit_gcc_yfcc_30e.yml">cfg</a></td>
+    <td>43.7</td>
+    <td>52.3</td>
+    <td>22.4</td>
+    <td>24.3</td>
+  </tr>
+  <tr>
+    <td><a href="configs/group_vit_gcc_redcap_30e.yml">cfg</a></td>
+    <td>51.6</td>
+    <td>50.8</td>
+    <td>23.7</td>
+    <td>27.5</td>
+  </tr>
+</tbody>
+</table>
+
+Pre-trained weights `group_vit_gcc_yfcc_30e-879422e0.pth` and `group_vit_gcc_redcap_30e-3dd09a76.pth` can be downloaded from [Jiarui Xu's GitHub](https://github.com/xvjiarui/GroupViT#benchmark).
+
+<div align="center">
+<img src="figs/github_voc.gif" width="32%">
+<img src="figs/github_ctx.gif" width="32%">
+<img src="figs/github_coco.gif" width="32%">
+</div>
+
+<details><summary>Zero-shot Transfer to Classification on ImageNet</summary> <pre><code>./tools/dist_launch.sh main_group_vit.py /path/to/config 8 --resume /path/to/checkpoint --eval</code></pre> </details>
+<details><summary>Zero-shot Transfer to Semantic Segmentation on Pascal VOC</summary><pre><code>./tools/dist_launch.sh main_seg.py /path/to/config 8 --resume /path/to/checkpoint</code></pre></details>
+<details><summary>Zero-shot Transfer to Semantic Segmentation on Pascal Context</summary><pre><code>./tools/dist_launch.sh main_seg.py /path/to/config 8 --resume /path/to/checkpoint --opts evaluate.seg.cfg=segmentation/configs/_base_/datasets/pascal_context.py</code></pre></details>
+<details><summary>Zero-shot Transfer to Semantic Segmentation on COCO</summary><pre><code>./tools/dist_launch.sh main_seg.py /path/to/config 8 --resume /path/to/checkpoint --opts evaluate.seg.cfg=segmentation/configs/_base_/datasets/coco.py</code></pre></details>
+
+## Data Preparation
+
+During training, we use [webdataset](https://webdataset.github.io/webdataset/) for scalable data loading.
+To convert image-text pairs into webdataset format, we use the [img2dataset](https://github.com/rom1504/img2dataset) tool to download and preprocess the datasets.
+
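+For reference, the resulting shards can be inspected directly with `webdataset`. The snippet below is a minimal sketch (the shard path is only an example) using the same `jpg`/`text`/`txt` key patterns that `datasets/builder.py` reads:
+
+```python
+# Minimal sketch: iterate over one webdataset shard of image-text pairs.
+import webdataset as wds
+
+dataset = (
+    wds.WebDataset('local_data/gcc3m_shards/gcc-train-000000.tar')  # example shard path
+    .decode('pil')                            # decode images to PIL.Image
+    .to_tuple('jpg;png;jpeg', 'text;txt')     # (image, caption) pairs
+)
+
+for image, caption in dataset:
+    print(image.size, caption[:80])
+    break
+```
+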
+For inference, we use [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) for semantic segmentation testing, evaluation and visualization on Pascal VOC, Pascal Context and COCO datasets.
+
+The overall file structure is as follows:
+
+```shell
+GroupViT
+├── local_data
+│   ├── gcc3m_shards
+│   │   ├── gcc-train-000000.tar
+│   │   ├── ...
+│   │   ├── gcc-train-000436.tar
+│   ├── gcc12m_shards
+│   │   ├── gcc-conceptual-12m-000000.tar
+│   │   ├── ...
+│   │   ├── gcc-conceptual-12m-001943.tar
+│   ├── yfcc14m_shards
+│   │   ├── yfcc14m-000000.tar
+│   │   ├── ...
+│   │   ├── yfcc14m-001888.tar
+│   ├── redcap12m_shards
+│   │   ├── redcap12m-000000.tar
+│   │   ├── ...
+│   │   ├── redcap12m-001211.tar
+│   ├── imagenet_shards
+│   │   ├── imagenet-val-000000.tar
+│   │   ├── ...
+│   │   ├── imagenet-val-000049.tar
+│   ├── VOCdevkit
+│   │   ├── VOC2012
+│   │   │   ├── JPEGImages
+│   │   │   ├── SegmentationClass
+│   │   │   ├── ImageSets
+│   │   │   │   ├── Segmentation
+│   │   ├── VOC2010
+│   │   │   ├── JPEGImages
+│   │   │   ├── SegmentationClassContext
+│   │   │   ├── ImageSets
+│   │   │   │   ├── SegmentationContext
+│   │   │   │   │   ├── train.txt
+│   │   │   │   │   ├── val.txt
+│   │   │   ├── trainval_merged.json
+│   │   ├── VOCaug
+│   │   │   ├── dataset
+│   │   │   │   ├── cls
+│   ├── coco
+│   │   ├── images
+│   │   │   ├── train2017
+│   │   │   ├── val2017
+│   │   ├── annotations
+│   │   │   ├── train2017
+│   │   │   ├── val2017
+```
+
+The instructions for preparing each dataset are as follows.
+
+### GCC3M
+
+Please download the training split annotation file from [Conceptual Captions 3M](https://ai.google.com/research/ConceptualCaptions/download) and name it `gcc3m.tsv`.
+
+Then run `img2dataset` to download the image-text pairs and save them in webdataset format.
+```
+sed -i '1s/^/caption\turl\n/' gcc3m.tsv
+img2dataset --url_list gcc3m.tsv --input_format "tsv" \
+            --url_col "url" --caption_col "caption" --output_format webdataset \
+            --output_folder local_data/gcc3m_shards \
+            --processes_count 16 --thread_count 64 \
+            --image_size 512 --resize_mode keep_ratio --resize_only_if_bigger True \
+            --enable_wandb True --save_metadata False --oom_shard_count 6
+rename -d 's/^/gcc-train-/' local_data/gcc3m_shards/*
+```
+Please refer to [img2dataset CC3M tutorial](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md) for details.
+
+### GCC12M
+
+Please download the annotation file from [Conceptual 12M](https://github.com/google-research-datasets/conceptual-12m) and name it `gcc12m.tsv`.
+
+Then run `img2dataset` to download the image-text pairs and save them in webdataset format.
+```
+sed -i '1s/^/caption\turl\n/' gcc12m.tsv
+img2dataset --url_list gcc12m.tsv --input_format "tsv" \
+            --url_col "url" --caption_col "caption" --output_format webdataset \
+            --output_folder local_data/gcc12m_shards \
+            --processes_count 16 --thread_count 64 \
+            --image_size 512 --resize_mode keep_ratio --resize_only_if_bigger True \
+            --enable_wandb True --save_metadata False --oom_shard_count 6
+rename -d 's/^/gcc-conceptual-12m-/' local_data/gcc12m_shards/*
+```
+Please refer to [img2dataset CC12M tutorial](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md) for details.
+
+### YFCC14M
+Please follow the [CLIP Data Preparation](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md) instructions to download the YFCC14M subset.
+```
+wget https://openaipublic.azureedge.net/clip/data/yfcc100m_subset_data.tsv.bz2
+bunzip2 yfcc100m_subset_data.tsv.bz2
+```
+
+Then run the preprocessing script to create the subset SQLite database and annotation TSV file (this may take a while).
+```
+python convert_dataset/create_subset.py --input-dir . --output-dir . --subset yfcc100m_subset_data.tsv
+```
+This script creates two files: the SQLite database `yfcc100m_dataset.sql` and the annotation TSV file `yfcc14m_dataset.tsv`.
+
+Then follow the [YFCC100M Download Instructions](https://gitlab.com/jfolz/yfcc100m/-/tree/master) to download the dataset and its metadata files.
+```
+pip install git+https://gitlab.com/jfolz/yfcc100m.git
+mkdir -p yfcc100m_meta
+python -m yfcc100m.convert_metadata . -o yfcc100m_meta --skip_verification
+mkdir -p yfcc100m_zip
+python -m yfcc100m.download yfcc100m_meta -o yfcc100m_zip
+```
+
+Finally, convert the dataset into webdataset format.
+```
+python convert_dataset/convert_yfcc14m.py --root yfcc100m_zip --info yfcc14m_dataset.tsv --shards yfcc14m_shards
+```
+
+### RedCaps12M
+
+Please download the annotation file from [RedCaps](https://redcaps.xyz/).
+```
+wget https://www.dropbox.com/s/cqtdpsl4hewlli1/redcaps_v1.0_annotations.zip?dl=1 -O redcaps_v1.0_annotations.zip
+unzip redcaps_v1.0_annotations.zip
+```
+
+Then run the preprocessing script and `img2dataset` to download the image-text pairs and save them in webdataset format.
+```
+python convert_dataset/process_redcaps.py annotations redcaps12m_meta/redcaps12m.parquet --num-split 16
+img2dataset --url_list redcaps12m_meta/ --input_format "parquet" \
+            --url_col "URL" --caption_col "TEXT" --output_format webdataset \
+            --output_folder local_data/redcap12m_shards \
+            --processes_count 16 --thread_count 64 \
+            --image_size 512 --resize_mode keep_ratio --resize_only_if_bigger True \
+            --enable_wandb True --save_metadata False --oom_shard_count 6
+rename -d 's/^/redcap12m-/' local_data/redcap12m_shards/*
+```
+
+### ImageNet
+
+Please follow [webdataset ImageNet Example](https://github.com/tmbdev-archive/webdataset-examples/blob/master/makeshards.py) to convert ImageNet into webdataset format.
+
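+The loader in `datasets/builder.py` expects each ImageNet sample to provide a `jpg` image and a `cls` label, and the default config expects 50 validation shards named `imagenet-val-{000000..000049}.tar`. Below is a minimal sketch (not the referenced script) of packing the validation set with `wds.ShardWriter`; the directory layout and the WordNet-ID-to-index mapping are assumptions.
+
+```python
+# Minimal sketch: pack the ImageNet validation set into webdataset shards.
+# Assumes local_data/imagenet/val/<wnid>/*.JPEG with one folder per class.
+import os
+import webdataset as wds
+
+val_root = 'local_data/imagenet/val'
+wnids = sorted(os.listdir(val_root))                      # assumed layout: one folder per class
+class_index = {wnid: i for i, wnid in enumerate(wnids)}   # assumed wnid -> class index mapping
+
+# maxcount=1000 yields 50 shards for the 50k validation images.
+with wds.ShardWriter('local_data/imagenet_shards/imagenet-val-%06d.tar', maxcount=1000) as sink:
+    for wnid in wnids:
+        for fname in sorted(os.listdir(os.path.join(val_root, wnid))):
+            with open(os.path.join(val_root, wnid, fname), 'rb') as f:
+                image = f.read()                          # raw JPEG bytes
+            sink.write({
+                '__key__': os.path.splitext(fname)[0],
+                'jpg': image,                             # image bytes, decoded at load time
+                'cls': class_index[wnid],                 # integer class label
+            })
+```
+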
+### Pascal VOC
+
+Please follow [MMSegmentation Pascal VOC Preparation](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#pascal-voc) to download and setup the Pascal VOC dataset.
+
+### Pascal Context
+
+Please refer to [MMSegmentation Pascal Context Preparation](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#pascal-context) to download and setup the Pascal Context dataset.
+
+### COCO
+
+[COCO dataset](https://cocodataset.org/) is an object detection dataset with instance segmentation annotations.
+To evaluate GroupViT, we combine all the instance masks together and generate semantic segmentation maps.
+To generate the semantic segmentation maps, please follow [MMSegmentation's documentation](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#coco-stuff-164k) to download the COCO-Stuff-164k dataset first, then run the following:
+
+```shell
+python convert_dataset/convert_coco_object.py local_data/coco/ -o local_data/coco/
+```
+
+## Run Experiments
+
+### Pre-train
+
+Train on a single node:
+
+```shell
+(node0)$ ./tools/dist_launch.sh main_group_vit.py /path/to/config $GPUS_PER_NODE
+```
+
+For example, to train on a node with 8 GPUs, run:
+```shell
+(node0)$ ./tools/dist_launch.sh main_group_vit.py configs/group_vit_gcc_yfcc_30e.yml 8
+```
+
+Train on multiple nodes:
+
+```shell
+(node0)$ ./tools/dist_mn_launch.sh main_group_vit.py /path/to/config $NODE_RANK $NUM_NODES $GPUS_PER_NODE $MASTER_ADDR
+(node1)$ ./tools/dist_mn_launch.sh main_group_vit.py /path/to/config $NODE_RANK $NUM_NODES $GPUS_PER_NODE $MASTER_ADDR
+```
+
+For example, to train on two nodes with 8 GPUs each, run:
+
+```shell
+(node0)$ ./tools/dist_mn_launch.sh main_group_vit.py configs/group_vit_gcc_yfcc_30e.yml 0 2 8 tcp://node0
+(node1)$ ./tools/dist_mn_launch.sh main_group_vit.py configs/group_vit_gcc_yfcc_30e.yml 1 2 8 tcp://node0
+```
+
+We use 16 GPUs for pre-training in our paper.
+
+### Zero-shot Transfer to Semantic Segmentation
+
+#### Pascal VOC
+
+```shell
+./tools/dist_launch.sh main_seg.py /path/to/config $NUM_GPUS --resume /path/to/checkpoint
+```
+
+#### Pascal Context
+
+```shell
+./tools/dist_launch.sh main_seg.py /path/to/config $NUM_GPUS --resume /path/to/checkpoint --opts evaluate.seg.cfg=segmentation/configs/_base_/datasets/pascal_context.py
+```
+
+#### COCO
+
+```shell
+./tools/dist_launch.sh main_seg.py /path/to/config $NUM_GPUS --resume /path/to/checkpoint --opts evaluate.seg.cfg=segmentation/configs/_base_/datasets/coco.py
+```

+ 109 - 0
configs/default.yml

@@ -0,0 +1,109 @@
+data:
+  batch_size: 256
+  pin_memory: true
+  num_workers: 6
+  # Thomas said it should be at least about 5-10x your batch size; beyond that,
+  # the differences become academic.
+  shuffle_buffer: 10000
+  seed: ${train.seed}
+  dataset:
+    meta:
+      gcc3m:
+        type: img_txt_pair
+        path: local_data/gcc3m_shards
+        prefix: gcc-train-{000000..000436}.tar
+        length: 2891445
+      gcc12m:
+        type: img_txt_pair
+        path: local_data/gcc12m_shards
+        prefix: gcc-conceptual-12m-{000000..001943}.tar
+        length: 11156203
+      yfcc14m:
+        type: img_txt_pair
+        path: local_data/yfcc14m_shards
+        prefix: yfcc14m-{000000..001888}.tar
+        length: 14615499
+      redcap12m:
+        type: img_txt_pair
+        path: local_data/redcap12m_shards
+        prefix: redcap12m-{000000..001211}.tar
+        length: 11866987
+      imagenet:
+        type: img_cls_pair
+        path: local_data/imagenet_shards
+        prefix: imagenet-val-{000000..000049}.tar
+        length: 50000
+    train:
+      - gcc3m
+      - gcc12m
+      - yfcc14m
+    val:
+      - imagenet
+
+  img_aug:
+    deit_aug: true
+    img_size: 224
+    img_scale: [0.08, 1.0]
+    interpolation: bilinear
+    color_jitter: 0.4
+    auto_augment: 'rand-m9-mstd0.5-inc1'
+    re_prob: 0.25
+    re_mode: 'pixel'
+    re_count: 1
+  text_aug:
+    max_seq_len: 77
+    multi_label: 0
+    word_type: 'noun'
+
+train:
+  start_epoch: 0
+  epochs: 30
+  warmup_epochs: 2
+  base_lr: 1.6e-3
+  weight_decay: 0.05
+  warmup_lr: 4e-6
+  min_lr: 4e-5
+  clip_grad: 5.0
+  accumulation_steps: 0
+  amp_opt_level: O1
+  seed: 0
+
+  lr_scheduler:
+    name: cosine
+
+  optimizer:
+    name: adamw
+    eps: 1e-8
+    betas: [0.9, 0.999]
+
+evaluate:
+  eval_only: false
+  eval_freq: 1
+  task:
+    - cls
+    - seg
+  cls:
+    save_best: true
+    template: subset
+  seg:
+    save_best: true
+    cfg: segmentation/configs/_base_/datasets/pascal_voc12.py
+    template: simple
+    opts: []
+
+checkpoint:
+  auto_resume: true
+  resume: ''
+  freq: 1
+  max_kept: -1
+  save_freq: 1
+
+
+model_name: '' # display name in the logger
+output: ???
+tag: default
+print_freq: 10
+seed: 0
+wandb: false
+local_rank: ???
+vis: []

+ 29 - 0
configs/group_vit_gcc_redcap_30e.yml

@@ -0,0 +1,29 @@
+_base_: 'default.yml'
+
+model:
+  type: MultiLabelContrastive
+  img_encoder:
+    type: GroupViT
+    embed_dim: 384
+    num_heads: [6, 6, 6]
+    depths: [6, 3, 3]
+    num_group_tokens: [64, 8, 0]
+    num_output_groups: [64, 8]
+    drop_rate: 0.0
+    drop_path_rate: 0.1
+  text_encoder:
+    type: TextTransformer
+    context_length: 77
+    width: 256
+    layers: 12
+    vocab_size: 49408
+  contrast_temperature: 0.07
+  proj_num_layers: 2
+  output_dim: 256
+  multi_label: ${data.text_aug.multi_label} # multi_label=0 is better for RedCap
+data:
+  dataset:
+    train:
+      - gcc3m
+      - gcc12m
+      - redcap12m

+ 26 - 0
configs/group_vit_gcc_yfcc_30e.yml

@@ -0,0 +1,26 @@
+_base_: 'default.yml'
+data:
+  text_aug:
+    multi_label: 3
+
+model:
+  type: MultiLabelContrastive
+  img_encoder:
+    type: GroupViT
+    embed_dim: 384
+    num_heads: [6, 6, 6]
+    depths: [6, 3, 3]
+    num_group_tokens: [64, 8, 0]
+    num_output_groups: [64, 8]
+    drop_rate: 0.0
+    drop_path_rate: 0.1
+  text_encoder:
+    type: TextTransformer
+    context_length: 77
+    width: 256
+    layers: 12
+    vocab_size: 49408
+  contrast_temperature: 0.07
+  proj_num_layers: 2
+  output_dim: 256
+  multi_label: ${data.text_aug.multi_label}

+ 282 - 0
convert_dataset/convert_coco_object.py

@@ -0,0 +1,282 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import argparse
+import os.path as osp
+import shutil
+from functools import partial
+from glob import glob
+
+import mmcv
+import numpy as np
+from PIL import Image
+
+COCO_LEN = 123287
+
+clsID_to_trID = {
+    0: 0,
+    1: 1,
+    2: 2,
+    3: 3,
+    4: 4,
+    5: 5,
+    6: 6,
+    7: 7,
+    8: 8,
+    9: 9,
+    10: 10,
+    12: 11,
+    13: 12,
+    14: 13,
+    15: 14,
+    16: 15,
+    17: 16,
+    18: 17,
+    19: 18,
+    20: 19,
+    21: 20,
+    22: 21,
+    23: 22,
+    24: 23,
+    26: 24,
+    27: 25,
+    30: 26,
+    31: 27,
+    32: 28,
+    33: 29,
+    34: 30,
+    35: 31,
+    36: 32,
+    37: 33,
+    38: 34,
+    39: 35,
+    40: 36,
+    41: 37,
+    42: 38,
+    43: 39,
+    45: 40,
+    46: 41,
+    47: 42,
+    48: 43,
+    49: 44,
+    50: 45,
+    51: 46,
+    52: 47,
+    53: 48,
+    54: 49,
+    55: 50,
+    56: 51,
+    57: 52,
+    58: 53,
+    59: 54,
+    60: 55,
+    61: 56,
+    62: 57,
+    63: 58,
+    64: 59,
+    66: 60,
+    69: 61,
+    71: 62,
+    72: 63,
+    73: 64,
+    74: 65,
+    75: 66,
+    76: 67,
+    77: 68,
+    78: 69,
+    79: 70,
+    80: 71,
+    81: 72,
+    83: 73,
+    84: 74,
+    85: 75,
+    86: 76,
+    87: 77,
+    88: 78,
+    89: 79,
+    91: 80,
+    92: 81,
+    93: 82,
+    94: 83,
+    95: 84,
+    96: 85,
+    97: 86,
+    98: 87,
+    99: 88,
+    100: 89,
+    101: 90,
+    102: 91,
+    103: 92,
+    104: 93,
+    105: 94,
+    106: 95,
+    107: 96,
+    108: 97,
+    109: 98,
+    110: 99,
+    111: 100,
+    112: 101,
+    113: 102,
+    114: 103,
+    115: 104,
+    116: 105,
+    117: 106,
+    118: 107,
+    119: 108,
+    120: 109,
+    121: 110,
+    122: 111,
+    123: 112,
+    124: 113,
+    125: 114,
+    126: 115,
+    127: 116,
+    128: 117,
+    129: 118,
+    130: 119,
+    131: 120,
+    132: 121,
+    133: 122,
+    134: 123,
+    135: 124,
+    136: 125,
+    137: 126,
+    138: 127,
+    139: 128,
+    140: 129,
+    141: 130,
+    142: 131,
+    143: 132,
+    144: 133,
+    145: 134,
+    146: 135,
+    147: 136,
+    148: 137,
+    149: 138,
+    150: 139,
+    151: 140,
+    152: 141,
+    153: 142,
+    154: 143,
+    155: 144,
+    156: 145,
+    157: 146,
+    158: 147,
+    159: 148,
+    160: 149,
+    161: 150,
+    162: 151,
+    163: 152,
+    164: 153,
+    165: 154,
+    166: 155,
+    167: 156,
+    168: 157,
+    169: 158,
+    170: 159,
+    171: 160,
+    172: 161,
+    173: 162,
+    174: 163,
+    175: 164,
+    176: 165,
+    177: 166,
+    178: 167,
+    179: 168,
+    180: 169,
+    181: 170,
+    255: 255
+}
+
+# set to background
+for k, v in clsID_to_trID.items():
+    clsID_to_trID[k] = v + 1
+    if k > 90:
+        clsID_to_trID[k] = 0
+
+
+def convert_to_trainID(maskpath, out_mask_dir, is_train):
+    mask = np.array(Image.open(maskpath))
+    mask_copy = mask.copy()
+    for clsID, trID in clsID_to_trID.items():
+        mask_copy[mask == clsID] = trID
+    seg_filename = osp.join(
+        out_mask_dir, 'train2017',
+        osp.basename(maskpath).split('.')[0] +
+        '_instanceTrainIds.png') if is_train else osp.join(
+            out_mask_dir, 'val2017',
+            osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png')
+    Image.fromarray(mask_copy).save(seg_filename, 'PNG')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description=\
+        'Convert COCO Stuff 164k annotations to COCO Objects')  # noqa
+    parser.add_argument('coco_path', help='coco stuff path')
+    parser.add_argument('-o', '--out_dir', help='output path')
+    parser.add_argument(
+        '--nproc', default=16, type=int, help='number of process')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    coco_path = args.coco_path
+    nproc = args.nproc
+
+    out_dir = args.out_dir or coco_path
+    out_img_dir = osp.join(out_dir, 'images')
+    out_mask_dir = osp.join(out_dir, 'annotations')
+
+    mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
+    mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))
+
+    if out_dir != coco_path:
+        shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)
+
+    train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
+    train_list = [file for file in train_list if 'TrainIds' not in file]
+    test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
+    test_list = [file for file in test_list if 'TrainIds' not in file]
+    assert (len(train_list) +
+            len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
+                len(train_list), len(test_list))
+
+    if args.nproc > 1:
+        mmcv.track_parallel_progress(
+            partial(
+                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
+            train_list,
+            nproc=nproc)
+        mmcv.track_parallel_progress(
+            partial(
+                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
+            test_list,
+            nproc=nproc)
+    else:
+        mmcv.track_progress(
+            partial(
+                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
+            train_list)
+        mmcv.track_progress(
+            partial(
+                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
+            test_list)
+
+    print('Done!')
+
+
+if __name__ == '__main__':
+    main()

+ 119 - 0
convert_dataset/convert_yfcc14m.py

@@ -0,0 +1,119 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import argparse
+import json
+import os
+import os.path as osp
+import random
+import sys
+import zipfile
+
+import numpy as np
+import pandas as pd
+import webdataset as wds
+from tqdm import tqdm
+import mmcv
+
+def write_dataset(args):
+
+    df = pd.read_csv(
+        args.info, sep='\t', index_col='file', dtype=str, lineterminator='\n')
+    print(f'Loaded dataframe: \n{df}')
+    print(f'Length: \n{len(df)}')
+
+    # This is the output pattern under which we write shards.
+    pattern = os.path.join(args.shards, f'yfcc14m-%06d.tar')
+
+    with wds.ShardWriter(
+            pattern, maxsize=int(args.maxsize),
+            maxcount=int(args.maxcount)) as sink:
+        sink.verbose = 0
+        all_keys = set()
+
+        skipped = 0
+        zip_files = list(mmcv.scandir(args.root, suffix='zip'))
+        for idx, file in tqdm(
+                enumerate(zip_files), desc='total', total=len(zip_files)):
+            with zipfile.ZipFile(osp.join(args.root, file), 'r') as zfile:
+                filename_list = zfile.namelist()
+                for filename in tqdm(
+                        filename_list, position=1, desc=f'{file}', leave=None):
+                    image = zfile.read(filename)
+                    if image is None:
+                        skipped += 1
+                        tqdm.write(f'Skipping {filename}, {skipped}/{len(df)}')
+                        continue
+                    fname = filename.replace('data/images/', '')
+                    # Construct a unique key from the filename.
+                    key = os.path.splitext(os.path.basename(fname))[0]
+
+                    # Useful check.
+                    if key in all_keys:
+                        tqdm.write(f'duplicate: {fname}')
+                        continue
+                    assert key not in all_keys
+                    all_keys.add(key)
+
+                    text = str(df.loc[fname]['caption'])
+
+                    if len(text.split(' ')) < 2:
+                        skipped += 1
+                        tqdm.write(f'Text {text} too short')
+                        tqdm.write(f'Skipping {fname}, {skipped}/{len(df)}')
+                        continue
+
+                    # Construct a sample.
+                    xkey = key
+                    sample = {'__key__': xkey, 'jpg': image, 'text': text}
+
+                    # Write the sample to the sharded tar archives.
+                    sink.write(sample)
+        print(f'skipped: {skipped}/{len(df)}')
+        print(f'total keys: {len(all_keys)}')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Generate sharded webdataset from the original YFCC14M data.')
+    parser.add_argument('--maxsize', type=float, default=1e9)
+    parser.add_argument('--maxcount', type=float, default=100000)
+    parser.add_argument('--shards', help='directory where shards are written')
+    parser.add_argument('--root', help='data root path')
+    parser.add_argument('--info', help='tsv path')
+    args = parser.parse_args()
+
+    assert args.maxsize > 10000000
+    assert args.maxcount < 1000000
+    return args
+
+
+def main():
+    args = parse_args()
+
+    seed = 0
+    random.seed(seed)
+    np.random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+
+    if not os.path.isdir(os.path.join(args.shards, '.')):
+        print(
+            f'{args.shards}: should be a writable destination directory for shards',
+            file=sys.stderr)
+        sys.exit(1)
+
+    write_dataset(args=args)
+
+
+if __name__ == '__main__':
+    main()

+ 114 - 0
convert_dataset/create_subset.py

@@ -0,0 +1,114 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import argparse
+import os
+import os.path as osp
+import re
+import sqlite3
+from urllib.parse import unquote
+
+import pandas as pd
+from datadings.tools import locate_files
+from pandarallel import pandarallel
+from yfcc100m.convert_metadata import download_db
+from yfcc100m.vars import FILES
+
+pandarallel.initialize(progress_bar=True)
+
+
+def key2path(key):
+    img_path = osp.join(key[0:3], key[3:6], key + '.jpg')
+    return img_path
+
+
+def clean_caption(line):
+    line = unquote(str(line))
+    line = remove_html_tags(line)
+    return line.replace('\n', ' ').replace('+', ' ')
+
+
+def remove_html_tags(text):
+    """Remove html tags from a string"""
+    clean = re.compile('<.*?>')
+    return re.sub(clean, '', text)
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Create YFCC subset sql db and tsv')
+    parser.add_argument('--input-dir', help='input sql db file directory')
+    parser.add_argument('--output-dir', help='output tsv directory')
+    parser.add_argument(
+        '--subset', help='subset of data to use', default='yfcc100m_subset_data.tsv')
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    files = locate_files(FILES, args.input_dir)
+    # download DB file with AWS tools
+    download_db(files)
+
+    fullset_name = 'yfcc100m_dataset'
+    subset_name = 'yfcc14m_dataset'
+    conn = sqlite3.connect(osp.join(args.input_dir, 'yfcc100m_dataset.sql'))
+    # get column names
+    # some settings that hopefully speed up the queries
+    # conn.execute(f'PRAGMA query_only = YES')
+    conn.execute(f'PRAGMA journal_mode = OFF')
+    conn.execute(f'PRAGMA locking_mode = EXCLUSIVE')
+    conn.execute(f'PRAGMA page_size = 4096')
+    conn.execute(f'PRAGMA mmap_size = {4*1024*1024}')
+    conn.execute(f'PRAGMA cache_size = 10000')
+
+    print('reading subset data')
+    subset_df = pd.read_csv(args.subset, sep='\t', usecols=[1, 2], names=['photoid', 'photo_hash'], index_col='photoid')
+    subset_df.to_sql(subset_name, con=conn, if_exists='replace')
+
+    print('overwriting with subset')
+    select_query = f'select {fullset_name}.*, {subset_name}.photo_hash from {fullset_name} inner join {subset_name} on {fullset_name}.photoid = {subset_name}.photoid'
+    new_name = 'yfcc100m_dataset_new'
+    print('creating new table')
+    conn.execute(f'drop table if exists {new_name}')
+    conn.execute(' '.join([f'create table {new_name} as ', select_query]))
+    print(f'dropping {fullset_name}')
+    conn.execute(f'drop table if exists {fullset_name}')
+    print(f'dropping {subset_name}')
+    conn.execute(f'drop table if exists {subset_name}')
+    print(f'renaming {new_name} to {fullset_name}')
+    conn.execute(f'alter table {new_name} rename to {fullset_name}')
+    print('vacuuming db')
+    conn.execute('vacuum')
+
+    print(f'Loading dataframe from SQL')
+    anno_df = pd.read_sql(f'select * from {fullset_name}', con=conn)
+    print(f'Loaded dataframe from SQL: \n{anno_df.head()}')
+    print(f'Length: \n{len(anno_df)}')
+    print(f'generating filepath')
+    anno_df['file'] = anno_df['photo_hash'].parallel_map(key2path)
+    anno_df['caption'] = anno_df['description'].parallel_map(clean_caption)
+    anno_df = anno_df[['file', 'caption']]
+    print(f'Generated dataframe: \n{anno_df.head()}')
+
+    print('saving subset as tsv')
+    os.makedirs(args.output_dir, exist_ok=True)
+    anno_df.to_csv(osp.join(args.output_dir, 'yfcc14m_dataset.tsv'), sep='\t', index=False)
+    conn.close()
+
+
+if __name__ == '__main__':
+    main()

+ 71 - 0
convert_dataset/process_redcaps.py

@@ -0,0 +1,71 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import argparse
+import json
+import os
+import os.path as osp
+import random
+
+import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
+import tqdm
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        'input', type=str, help='path to redcaps annotations directory')
+    parser.add_argument(
+        'output', type=str, help='output annotations file path')
+    parser.add_argument(
+        '--num-split', type=int, help='number of splits to make')
+    return parser
+
+
+def main(args):
+    annos = []
+    for fname in tqdm.tqdm(os.listdir(args.input), desc='merging json files'):
+        if fname.endswith('json'):
+            with open(os.path.join(args.input, fname)) as f:
+                a = json.load(f)
+                for d in a['annotations']:
+                    cur_d = {'URL': d['url'], 'TEXT': d['caption']}
+                    annos.append(cur_d)
+
+    random.seed(42)
+    random.shuffle(annos)
+    if args.num_split is None:
+        df = pd.DataFrame(annos)
+        print(df.head())
+        print(f'saving {len(df)} annotations to {args.output}')
+        table = pa.Table.from_pandas(df)
+        os.makedirs(osp.dirname(args.output), exist_ok=True)
+        pq.write_table(table, args.output)
+    else:
+        for i in range(args.num_split):
+            df = pd.DataFrame(annos[i::args.num_split])
+            print(df.head())
+            output = osp.splitext(
+                args.output)[0] + f'_part{i}{osp.splitext(args.output)[1]}'
+            print(f'saving {len(df)} annotations to {output}')
+            table = pa.Table.from_pandas(df)
+            os.makedirs(osp.dirname(output), exist_ok=True)
+            pq.write_table(table, output)
+
+
+if __name__ == '__main__':
+    parser = get_args_parser()
+    args = parser.parse_args()
+    main(args)

+ 20 - 0
datasets/__init__.py

@@ -0,0 +1,20 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+
+from .builder import build_loader, build_text_transform
+from .imagenet_template import imagenet_classes, template_meta
+
+__all__ = [
+    'build_loader', 'build_text_transform', 'template_meta', 'imagenet_classes'
+]

binary
datasets/bpe_simple_vocab_16e6.txt.gz


+ 337 - 0
datasets/builder.py

@@ -0,0 +1,337 @@
+# -------------------------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE
+#
+# Written by Ze Liu, Zhenda Xie
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import os.path as osp
+import random
+import warnings
+from functools import partial
+
+import nltk
+import numpy as np
+import torch
+import torch.distributed as dist
+import webdataset as wds
+from braceexpand import braceexpand
+from mmcv.parallel import collate
+from timm.data import create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data.transforms import _pil_interp
+from torchvision import transforms
+
+from .formatting import ToDataContainer
+from .tokenizer import SimpleTokenizer
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    # The seed of each worker equals to
+    # num_worker * rank + worker_id + user_seed
+    worker_seed = num_workers * rank + worker_id + seed
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+
+
+def build_loader(config):
+    local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0
+
+    dataset_train = build_dataset(is_train=True, config=config)
+    print(f'local rank {local_rank} / global rank {dist.get_rank()} \
+        successfully build train dataset')
+    dataset_val = build_dataset(is_train=False, config=config)
+    print(f'local rank {local_rank} / global rank {dist.get_rank()} \
+        successfully build val dataset')
+
+    dc_collate = partial(collate, samples_per_gpu=config.batch_size)
+    train_len = len(dataset_train)
+    init_fn = partial(worker_init_fn, num_workers=config.num_workers, rank=dist.get_rank(), seed=config.seed)
+    data_loader_train = wds.WebLoader(
+        dataset_train.batched(config.batch_size, dc_collate, partial=False),
+        batch_size=None,
+        shuffle=False,
+        num_workers=config.num_workers,
+        pin_memory=config.pin_memory,
+        persistent_workers=config.num_workers > 0,
+        worker_init_fn=init_fn)
+
+    train_nbatches = max(1, train_len // (config.batch_size * dist.get_world_size()))
+    data_loader_train = (data_loader_train.with_epoch(train_nbatches).with_length(train_nbatches))
+
+    data_loader_val = wds.WebLoader(
+        dataset_val.batched(config.batch_size, dc_collate),
+        batch_size=None,
+        shuffle=False,
+        num_workers=config.num_workers,
+        pin_memory=config.pin_memory,
+        persistent_workers=config.num_workers > 0,
+        worker_init_fn=init_fn)
+
+    val_len = len(dataset_val)
+    val_nbatches = max(1, val_len // (config.batch_size * dist.get_world_size()))
+    data_loader_val = (data_loader_val.with_epoch(val_nbatches).with_length(val_nbatches))
+
+    return dataset_train, dataset_val, data_loader_train, data_loader_val
+
+
+def warn_and_continue(exn):
+    """Call in an exception handler to ignore any exception, issue a warning,
+    and continue."""
+    warnings.warn(repr(exn))
+    return True
+
+
+def build_dataset(is_train, config):
+    img_transform = build_img_transform(is_train, config.img_aug)
+    text_transform = build_text_transform(is_train, config.text_aug)
+    split = 'train' if is_train else 'val'
+    dataset_type = None
+    tar_file_list = []
+    total_length = 0
+    for ds in config.dataset[split]:
+        ds_meta = config.dataset.meta[ds]
+        if dataset_type is None:
+            dataset_type = ds_meta.type
+        else:
+            assert dataset_type == ds_meta.type, \
+                'All datasets must be of the same type'
+
+        prefix = ds_meta.prefix
+        path = ds_meta.path
+        length = ds_meta.length
+        cur_tar_file_list = []
+        for tar_file in braceexpand(osp.join(path, prefix)):
+            if osp.exists(tar_file):
+                cur_tar_file_list.append(tar_file)
+        print(f'Found {len(cur_tar_file_list)} files for dataset {ds}')
+        tar_file_list.extend(cur_tar_file_list)
+        total_length += length
+    print(f'Found {len(tar_file_list)} files in total for split {split}')
+    # yapf: disable
+    if is_train:
+        dataset = (  # noqa
+            wds.WebDataset(tar_file_list, repeat=True, handler=warn_and_continue)
+            .shuffle(config.shuffle_buffer)
+            .decode('pil', handler=warn_and_continue)
+            .rename(image='jpg;png;jpeg', text='text;txt', keep=False, handler=warn_and_continue)
+            .map_dict(image=img_transform, text=text_transform, handler=warn_and_continue)
+            .with_length(total_length))
+    else:
+        # zero shot classification validation
+        dataset = (  # noqa
+            wds.WebDataset(tar_file_list, repeat=False, handler=warn_and_continue)
+            .shuffle(0)
+            .decode('pil', handler=warn_and_continue)
+            .rename(image='jpg;png;jpeg', target='cls', keep=False)
+            .map_dict(image=img_transform, target=ToDataContainer())
+            .slice(dist.get_rank(), total_length, dist.get_world_size())
+            .with_length(total_length))
+    # yapf: enable
+
+    return dataset
+
+
+def build_img_transform(is_train, config, with_dc=True):
+
+    if not config.deit_aug:
+        if is_train:
+            transform = transforms.Compose([
+                transforms.RandomResizedCrop(config.img_size, scale=config.img_scale),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
+            ])
+        else:
+            transform = transforms.Compose([
+                transforms.Resize(config.img_size + 32),
+                transforms.CenterCrop(config.img_size),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
+            ])
+
+        return transform
+
+    if is_train:
+        # this should always dispatch to transforms_imagenet_train
+        transform = create_transform(
+            input_size=config.img_size,
+            is_training=True,
+            color_jitter=config.color_jitter if config.color_jitter > 0 else None,
+            auto_augment=config.auto_augment if config.auto_augment != 'none' else None,
+            re_prob=config.re_prob,
+            re_mode=config.re_mode,
+            re_count=config.re_count,
+            interpolation=config.interpolation,
+        )
+    else:
+        size = int((256 / 224) * config.img_size)
+        transform = transforms.Compose([
+            transforms.Resize(size, interpolation=_pil_interp(config.interpolation)),
+            transforms.CenterCrop(config.img_size),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
+        ])
+
+    if with_dc:
+        transform = transforms.Compose([*transform.transforms, ToDataContainer()])
+
+    return transform
+
+
+def build_text_transform(is_train, config, with_dc=True):
+    local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0
+    if config.multi_label and is_train:
+        # only download on local rank 0
+        if local_rank == 0:
+            nltk.download('popular')
+        transform = WordAugTokenizeWrapper(
+            Tokenize(SimpleTokenizer(), max_seq_len=config.max_seq_len),
+            max_word=config.multi_label,
+            word_type=config.word_type)
+
+    else:
+        transform = Tokenize(SimpleTokenizer(), max_seq_len=config.max_seq_len)
+
+    if with_dc:
+        transform = transforms.Compose([transform, ToDataContainer()])
+
+    return transform
+
+
+class Tokenize:
+
+    def __init__(self, tokenizer, max_seq_len=77, truncate=True):
+        self.tokenizer = tokenizer
+        self.max_seq_len = max_seq_len
+        self.truncate = truncate
+
+    def __call__(self, texts):
+        expanded_dim = False
+        if isinstance(texts, str):
+            texts = [texts]
+            expanded_dim = True
+
+        sot_token = self.tokenizer.encoder['<|startoftext|>']
+        eot_token = self.tokenizer.encoder['<|endoftext|>']
+        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
+        result = torch.zeros(len(all_tokens), self.max_seq_len, dtype=torch.long)
+
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > self.max_seq_len:
+                if self.truncate:
+                    tokens = tokens[:self.max_seq_len]
+                    tokens[-1] = eot_token
+                else:
+                    raise RuntimeError(f'Input {texts[i]} is too long for context length {self.max_seq_len}')
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        if expanded_dim:
+            return result[0]
+
+        return result
+
+
+class WordAugTokenizeWrapper:
+
+    def __init__(self, tokenize, max_word=3, template_set='full', word_type='noun'):
+        self.tokenize = tokenize
+        self.max_word = max_word
+        from .imagenet_template import (full_imagenet_templates, sub_imagenet_template, simple_imagenet_template,
+                                        identity_template)
+        assert template_set in ['full', 'subset', 'simple', 'identity']
+        if template_set == 'full':
+            templates = full_imagenet_templates
+        elif template_set == 'subset':
+            templates = sub_imagenet_template
+        elif template_set == 'simple':
+            templates = simple_imagenet_template
+        elif template_set == 'identity':
+            templates = identity_template
+        else:
+            raise ValueError
+        self.templates = templates
+        assert word_type in ['noun', 'noun_phrase']
+        self.word_type = word_type
+
+    def get_tag(self, tokenized, tags):
+        if not isinstance(tags, (list, tuple)):
+            tags = [tags]
+        ret = []
+        for (word, pos) in nltk.pos_tag(tokenized):
+            for tag in tags:
+                if pos == tag:
+                    ret.append(word)
+        return ret
+
+    def get_noun_phrase(self, tokenized):
+        # Taken from Su Nam Kim Paper...
+        grammar = r"""
+            NBAR:
+                {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns
+
+            NP:
+                {<NBAR>}
+                {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
+        """
+        chunker = nltk.RegexpParser(grammar)
+
+        chunked = chunker.parse(nltk.pos_tag(tokenized))
+        continuous_chunk = []
+        current_chunk = []
+
+        for subtree in chunked:
+            if isinstance(subtree, nltk.Tree):
+                current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
+            elif current_chunk:
+                named_entity = ' '.join(current_chunk)
+                if named_entity not in continuous_chunk:
+                    continuous_chunk.append(named_entity)
+                    current_chunk = []
+            else:
+                continue
+
+        return continuous_chunk
+
+    def __call__(self, text):
+        assert isinstance(text, str)
+        tokenized = nltk.word_tokenize(text)
+        nouns = []
+        if len(tokenized) > 0:
+            if self.word_type == 'noun':
+                nouns = self.get_tag(tokenized, ['NN', 'NNS', 'NNP', 'VBG', 'VB', 'VBD', 'VBN', 'VBP', 'VBZ'])
+            elif self.word_type == 'noun_phrase':
+                nouns = self.get_noun_phrase(tokenized)
+            else:
+                raise ValueError('word_type must be noun or noun_phrase')
+
+        prompt_texts = []
+        if len(nouns) > 0:
+            select_nouns = np.random.choice(nouns, min(self.max_word, len(nouns)), replace=False)
+            prompt_texts = [np.random.choice(self.templates).format(noun) for noun in select_nouns]
+        if len(prompt_texts) < self.max_word:
+            prompt_texts += [text] * (self.max_word - len(prompt_texts))
+
+        texts = [text] + prompt_texts
+        return self.tokenize(texts)

+ 36 - 0
datasets/formatting.py

@@ -0,0 +1,36 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import torch
+from mmcv.parallel import DataContainer as DC
+
+
+class ToDataContainer(object):
+    """Convert results to :obj:`mmcv.DataContainer`"""
+
+    def __call__(self, sample):
+        """Call function to convert data in results to
+        :obj:`mmcv.DataContainer`.
+
+        Args:
+            sample (torch.Tensor): Input sample.
+
+        Returns:
+            DataContainer
+        """
+        if isinstance(sample, int):
+            sample = torch.tensor(sample)
+        return DC(sample, stack=True, pad_dims=None)
+
+    def __repr__(self):
+        return self.__class__.__name__

+ 267 - 0
datasets/imagenet_template.py

@@ -0,0 +1,267 @@
+# -------------------------------------------------------------------------
+# MIT License
+#
+# Copyright (c) 2021 OpenAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+full_imagenet_templates = [
+    'a bad photo of a {}.',
+    'a photo of many {}.',
+    'a sculpture of a {}.',
+    'a photo of the hard to see {}.',
+    'a low resolution photo of the {}.',
+    'a rendering of a {}.',
+    'graffiti of a {}.',
+    'a bad photo of the {}.',
+    'a cropped photo of the {}.',
+    'a tattoo of a {}.',
+    'the embroidered {}.',
+    'a photo of a hard to see {}.',
+    'a bright photo of a {}.',
+    'a photo of a clean {}.',
+    'a photo of a dirty {}.',
+    'a dark photo of the {}.',
+    'a drawing of a {}.',
+    'a photo of my {}.',
+    'the plastic {}.',
+    'a photo of the cool {}.',
+    'a close-up photo of a {}.',
+    'a black and white photo of the {}.',
+    'a painting of the {}.',
+    'a painting of a {}.',
+    'a pixelated photo of the {}.',
+    'a sculpture of the {}.',
+    'a bright photo of the {}.',
+    'a cropped photo of a {}.',
+    'a plastic {}.',
+    'a photo of the dirty {}.',
+    'a jpeg corrupted photo of a {}.',
+    'a blurry photo of the {}.',
+    'a photo of the {}.',
+    'a good photo of the {}.',
+    'a rendering of the {}.',
+    'a {} in a video game.',
+    'a photo of one {}.',
+    'a doodle of a {}.',
+    'a close-up photo of the {}.',
+    'a photo of a {}.',
+    'the origami {}.',
+    'the {} in a video game.',
+    'a sketch of a {}.',
+    'a doodle of the {}.',
+    'a origami {}.',
+    'a low resolution photo of a {}.',
+    'the toy {}.',
+    'a rendition of the {}.',
+    'a photo of the clean {}.',
+    'a photo of a large {}.',
+    'a rendition of a {}.',
+    'a photo of a nice {}.',
+    'a photo of a weird {}.',
+    'a blurry photo of a {}.',
+    'a cartoon {}.',
+    'art of a {}.',
+    'a sketch of the {}.',
+    'a embroidered {}.',
+    'a pixelated photo of a {}.',
+    'itap of the {}.',
+    'a jpeg corrupted photo of the {}.',
+    'a good photo of a {}.',
+    'a plushie {}.',
+    'a photo of the nice {}.',
+    'a photo of the small {}.',
+    'a photo of the weird {}.',
+    'the cartoon {}.',
+    'art of the {}.',
+    'a drawing of the {}.',
+    'a photo of the large {}.',
+    'a black and white photo of a {}.',
+    'the plushie {}.',
+    'a dark photo of a {}.',
+    'itap of a {}.',
+    'graffiti of the {}.',
+    'a toy {}.',
+    'itap of my {}.',
+    'a photo of a cool {}.',
+    'a photo of a small {}.',
+    'a tattoo of the {}.',
+]
+
+sub_imagenet_template = [
+    'itap of a {}.', 'a bad photo of a {}.', 'a origami {}.', 'a photo of the large {}.', 'a {} in a video game.',
+    'art of the {}.', 'a photo of the small {}.'
+]
+
+simple_imagenet_template = [
+    'a photo of a {}.',
+]
+
+identity_template = [
+    '{}',
+]
+
+template_meta = {
+    'full': full_imagenet_templates,
+    'subset': sub_imagenet_template,
+    'simple': simple_imagenet_template,
+    'identity': identity_template,
+}
+
+imagenet_classes = [
+    'tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead shark', 'electric ray', 'stingray', 'rooster',
+    'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', 'indigo bunting', 'American robin', 'bulbul',
+    'jay', 'magpie', 'chickadee', 'American dipper', 'kite (bird of prey)', 'bald eagle', 'vulture', 'great grey owl',
+    'fire salamander', 'smooth newt', 'newt', 'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog',
+    'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle', 'mud turtle', 'terrapin', 'box turtle',
+    'banded gecko', 'green iguana', 'Carolina anole', 'desert grassland whiptail lizard', 'agama',
+    'frilled-necked lizard', 'alligator lizard', 'Gila monster', 'European green lizard', 'chameleon', 'Komodo dragon',
+    'Nile crocodile', 'American alligator', 'triceratops', 'worm snake', 'ring-necked snake', 'eastern hog-nosed snake',
+    'smooth green snake', 'kingsnake', 'garter snake', 'water snake', 'vine snake', 'night snake', 'boa constrictor',
+    'African rock python', 'Indian cobra', 'green mamba', 'sea snake', 'Saharan horned viper',
+    'eastern diamondback rattlesnake', 'sidewinder rattlesnake', 'trilobite', 'harvestman', 'scorpion',
+    'yellow garden spider', 'barn spider', 'European garden spider', 'southern black widow', 'tarantula', 'wolf spider',
+    'tick', 'centipede', 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peafowl', 'quail',
+    'partridge', 'african grey parrot', 'macaw', 'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater',
+    'hornbill', 'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser', 'goose', 'black swan', 'tusker',
+    'echidna', 'platypus', 'wallaby', 'koala', 'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm',
+    'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton', 'chambered nautilus', 'Dungeness crab', 'rock crab',
+    'fiddler crab', 'red king crab', 'American lobster', 'spiny lobster', 'crayfish', 'hermit crab', 'isopod',
+    'white stork', 'black stork', 'spoonbill', 'flamingo', 'little blue heron', 'great egret', 'bittern bird',
+    'crane bird', 'limpkin', 'common gallinule', 'American coot', 'bustard', 'ruddy turnstone', 'dunlin',
+    'common redshank', 'dowitcher', 'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale',
+    'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin', 'Maltese', 'Pekingese', 'Shih Tzu',
+    'King Charles Spaniel', 'Papillon', 'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound', 'Beagle',
+    'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound', 'Treeing Walker Coonhound', 'English foxhound',
+    'Redbone Coonhound', 'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet', 'Ibizan Hound',
+    'Norwegian Elkhound', 'Otterhound', 'Saluki', 'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier',
+    'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', 'Kerry Blue Terrier', 'Irish Terrier',
+    'Norfolk Terrier', 'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier', 'Lakeland Terrier',
+    'Sealyham Terrier', 'Airedale Terrier', 'Cairn Terrier', 'Australian Terrier', 'Dandie Dinmont Terrier',
+    'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer', 'Standard Schnauzer', 'Scottish Terrier',
+    'Tibetan Terrier', 'Australian Silky Terrier', 'Soft-coated Wheaten Terrier', 'West Highland White Terrier',
+    'Lhasa Apso', 'Flat-Coated Retriever', 'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever',
+    'Chesapeake Bay Retriever', 'German Shorthaired Pointer', 'Vizsla', 'English Setter', 'Irish Setter',
+    'Gordon Setter', 'Brittany dog', 'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel',
+    'Cocker Spaniel', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz', 'Schipperke', 'Groenendael dog', 'Malinois',
+    'Briard', 'Australian Kelpie', 'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie', 'Border Collie',
+    'Bouvier des Flandres dog', 'Rottweiler', 'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher',
+    'Greater Swiss Mountain Dog', 'Bernese Mountain Dog', 'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer',
+    'Bullmastiff', 'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. Bernard', 'husky', 'Alaskan Malamute',
+    'Siberian Husky', 'Dalmatian', 'Affenpinscher', 'Basenji', 'pug', 'Leonberger', 'Newfoundland dog',
+    'Great Pyrenees dog', 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'brussels griffon', 'Pembroke Welsh Corgi',
+    'Cardigan Welsh Corgi', 'Toy Poodle', 'Miniature Poodle', 'Standard Poodle',
+    'Mexican hairless dog (xoloitzcuintli)', 'grey wolf', 'Alaskan tundra wolf', 'red wolf or maned wolf', 'coyote',
+    'dingo', 'dhole', 'African wild dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox', 'grey fox', 'tabby cat',
+    'tiger cat', 'Persian cat', 'Siamese cat', 'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar',
+    'lion', 'tiger', 'cheetah', 'brown bear', 'American black bear', 'polar bear', 'sloth bear', 'mongoose', 'meerkat',
+    'tiger beetle', 'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle', 'dung beetle', 'rhinoceros beetle',
+    'weevil', 'fly', 'bee', 'ant', 'grasshopper', 'cricket insect', 'stick insect', 'cockroach', 'praying mantis',
+    'cicada', 'leafhopper', 'lacewing', 'dragonfly', 'damselfly', 'red admiral butterfly', 'ringlet butterfly',
+    'monarch butterfly', 'small white butterfly', 'sulphur butterfly', 'gossamer-winged butterfly', 'starfish',
+    'sea urchin', 'sea cucumber', 'cottontail rabbit', 'hare', 'Angora rabbit', 'hamster', 'porcupine', 'fox squirrel',
+    'marmot', 'beaver', 'guinea pig', 'common sorrel horse', 'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus',
+    'ox', 'water buffalo', 'bison', 'ram (adult male sheep)', 'bighorn sheep', 'Alpine ibex', 'hartebeest',
+    'impala (antelope)', 'gazelle', 'arabian camel', 'llama', 'weasel', 'mink', 'European polecat',
+    'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', 'three-toed sloth', 'orangutan', 'gorilla',
+    'chimpanzee', 'gibbon', 'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur',
+    'black-and-white colobus', 'proboscis monkey', 'marmoset', 'white-headed capuchin', 'howler monkey', 'titi monkey',
+    "Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur', 'indri', 'Asian elephant',
+    'African bush elephant', 'red panda', 'giant panda', 'snoek fish', 'eel', 'silver salmon', 'rock beauty fish',
+    'clownfish', 'sturgeon', 'gar fish', 'lionfish', 'pufferfish', 'abacus', 'abaya', 'academic gown', 'accordion',
+    'acoustic guitar', 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', 'amphibious vehicle',
+    'analog clock', 'apiary', 'apron', 'trash can', 'assault rifle', 'backpack', 'bakery', 'balance beam', 'balloon',
+    'ballpoint pen', 'Band-Aid', 'banjo', 'baluster / handrail', 'barbell', 'barber chair', 'barbershop', 'barn',
+    'barometer', 'barrel', 'wheelbarrow', 'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap', 'bath towel',
+    'bathtub', 'station wagon', 'lighthouse', 'beaker', 'military hat (bearskin or shako)', 'beer bottle', 'beer glass',
+    'bell tower', 'baby bib', 'tandem bicycle', 'bikini', 'ring binder', 'binoculars', 'birdhouse', 'boathouse',
+    'bobsleigh', 'bolo tie', 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'hunting bow', 'bow tie',
+    'brass memorial plaque', 'bra', 'breakwater', 'breastplate', 'broom', 'bucket', 'buckle', 'bulletproof vest',
+    'high-speed train', 'butcher shop', 'taxicab', 'cauldron', 'candle', 'cannon', 'canoe', 'can opener', 'cardigan',
+    'car mirror', 'carousel', 'tool kit', 'cardboard box / carton', 'car wheel', 'automated teller machine', 'cassette',
+    'cassette player', 'castle', 'catamaran', 'CD player', 'cello', 'mobile phone', 'chain', 'chain-link fence',
+    'chain mail', 'chainsaw', 'storage chest', 'chiffonier', 'bell or wind chime', 'china cabinet',
+    'Christmas stocking', 'church', 'movie theater', 'cleaver', 'cliff dwelling', 'cloak', 'clogs', 'cocktail shaker',
+    'coffee mug', 'coffeemaker', 'spiral or coil', 'combination lock', 'computer keyboard', 'candy store',
+    'container ship', 'convertible', 'corkscrew', 'cornet', 'cowboy boot', 'cowboy hat', 'cradle', 'construction crane',
+    'crash helmet', 'crate', 'infant bed', 'Crock Pot', 'croquet ball', 'crutch', 'cuirass', 'dam', 'desk',
+    'desktop computer', 'rotary dial telephone', 'diaper', 'digital clock', 'digital watch', 'dining table',
+    'dishcloth', 'dishwasher', 'disc brake', 'dock', 'dog sled', 'dome', 'doormat', 'drilling rig', 'drum', 'drumstick',
+    'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', 'electric locomotive', 'entertainment center',
+    'envelope', 'espresso machine', 'face powder', 'feather boa', 'filing cabinet', 'fireboat', 'fire truck',
+    'fire screen', 'flagpole', 'flute', 'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen',
+    'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat', 'garbage truck',
+    'gas mask or respirator', 'gas pump', 'goblet', 'go-kart', 'golf ball', 'golf cart', 'gondola', 'gong', 'gown',
+    'grand piano', 'greenhouse', 'radiator grille', 'grocery store', 'guillotine', 'hair clip', 'hair spray',
+    'half-track', 'hammer', 'hamper', 'hair dryer', 'hand-held computer', 'handkerchief', 'hard disk drive',
+    'harmonica', 'harp', 'combine harvester', 'hatchet', 'holster', 'home theater', 'honeycomb', 'hook', 'hoop skirt',
+    'gymnastic horizontal bar', 'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron', 'carved pumpkin', 'jeans',
+    'jeep', 'T-shirt', 'jigsaw puzzle', 'rickshaw', 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle',
+    'lampshade', 'laptop computer', 'lawn mower', 'lens cap', 'letter opener', 'library', 'lifeboat', 'lighter',
+    'limousine', 'ocean liner', 'lipstick', 'slip-on shoe', 'lotion', 'music speaker', 'loupe magnifying glass',
+    'sawmill', 'magnetic compass', 'messenger bag', 'mailbox', 'tights', 'one-piece bathing suit', 'manhole cover',
+    'maraca', 'marimba', 'mask', 'matchstick', 'maypole', 'maze', 'measuring cup', 'medicine cabinet', 'megalith',
+    'microphone', 'microwave oven', 'military uniform', 'milk can', 'minibus', 'miniskirt', 'minivan', 'missile',
+    'mitten', 'mixing bowl', 'mobile home', 'ford model t', 'modem', 'monastery', 'monitor', 'moped',
+    'mortar and pestle', 'graduation cap', 'mosque', 'mosquito net', 'vespa', 'mountain bike', 'tent', 'computer mouse',
+    'mousetrap', 'moving van', 'muzzle', 'metal nail', 'neck brace', 'necklace', 'baby pacifier', 'notebook computer',
+    'obelisk', 'oboe', 'ocarina', 'odometer', 'oil filter', 'pipe organ', 'oscilloscope', 'overskirt', 'bullock cart',
+    'oxygen mask', 'product packet / packaging', 'paddle', 'paddle wheel', 'padlock', 'paintbrush', 'pajamas', 'palace',
+    'pan flute', 'paper towel', 'parachute', 'parallel bars', 'park bench', 'parking meter', 'railroad car', 'patio',
+    'payphone', 'pedestal', 'pencil case', 'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'plectrum',
+    'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank', 'pill bottle', 'pillow', 'ping-pong ball',
+    'pinwheel', 'pirate ship', 'drink pitcher', 'block plane', 'planetarium', 'plastic bag', 'plate rack', 'farm plow',
+    'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho', 'pool table', 'soda bottle', 'plant pot',
+    "potter's wheel", 'power drill', 'prayer rug', 'printer', 'prison', 'missile', 'projector', 'hockey puck',
+    'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket', 'radiator', 'radio', 'radio telescope',
+    'rain barrel', 'recreational vehicle', 'fishing casting reel', 'reflex camera', 'refrigerator', 'remote control',
+    'restaurant', 'revolver', 'rifle', 'rocking chair', 'rotisserie', 'eraser', 'rugby ball', 'ruler measuring stick',
+    'sneaker', 'safe', 'safety pin', 'salt shaker', 'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale',
+    'school bus', 'schooner', 'scoreboard', 'CRT monitor', 'screw', 'screwdriver', 'seat belt', 'sewing machine',
+    'shield', 'shoe store', 'shoji screen / room divider', 'shopping basket', 'shopping cart', 'shovel', 'shower cap',
+    'shower curtain', 'ski', 'balaclava ski mask', 'sleeping bag', 'slide rule', 'sliding door', 'slot machine',
+    'snorkel', 'snowmobile', 'snowplow', 'soap dispenser', 'soccer ball', 'sock', 'solar thermal collector', 'sombrero',
+    'soup bowl', 'keyboard space bar', 'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web', 'spindle',
+    'sports car', 'spotlight', 'stage', 'steam locomotive', 'through arch bridge', 'steel drum', 'stethoscope', 'scarf',
+    'stone wall', 'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa', 'submarine', 'suit',
+    'sundial', 'sunglasses', 'sunglasses', 'sunscreen', 'suspension bridge', 'mop', 'sweatshirt',
+    'swim trunks / shorts', 'swing', 'electrical switch', 'syringe', 'table lamp', 'tank', 'tape player', 'teapot',
+    'teddy bear', 'television', 'tennis ball', 'thatched roof', 'front curtain', 'thimble', 'threshing machine',
+    'throne', 'tile roof', 'toaster', 'tobacco shop', 'toilet seat', 'torch', 'totem pole', 'tow truck', 'toy store',
+    'tractor', 'semi-trailer truck', 'tray', 'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch',
+    'trolleybus', 'trombone', 'hot tub', 'turnstile', 'typewriter keyboard', 'umbrella', 'unicycle', 'upright piano',
+    'vacuum cleaner', 'vase', 'vaulted or arched ceiling', 'velvet fabric', 'vending machine', 'vestment', 'viaduct',
+    'violin', 'volleyball', 'waffle iron', 'wall clock', 'wallet', 'wardrobe', 'military aircraft', 'sink',
+    'washing machine', 'water bottle', 'water jug', 'water tower', 'whiskey jug', 'whistle', 'hair wig',
+    'window screen', 'window shade', 'Windsor tie', 'wine bottle', 'airplane wing', 'wok', 'wooden spoon', 'wool',
+    'split-rail fence', 'shipwreck', 'sailboat', 'yurt', 'website', 'comic book', 'crossword', 'traffic or street sign',
+    'traffic light', 'dust jacket', 'menu', 'plate', 'guacamole', 'consomme', 'hot pot', 'trifle', 'ice cream',
+    'popsicle', 'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog', 'mashed potatoes', 'cabbage', 'broccoli',
+    'cauliflower', 'zucchini', 'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber', 'artichoke',
+    'bell pepper', 'cardoon', 'mushroom', 'Granny Smith apple', 'strawberry', 'orange', 'lemon', 'fig', 'pineapple',
+    'banana', 'jackfruit', 'cherimoya (custard apple)', 'pomegranate', 'hay', 'carbonara', 'chocolate syrup', 'dough',
+    'meatloaf', 'pizza', 'pot pie', 'burrito', 'red wine', 'espresso', 'tea cup', 'eggnog', 'mountain', 'bubble',
+    'cliff', 'coral reef', 'geyser', 'lakeshore', 'promontory', 'sandbar', 'beach', 'valley', 'volcano',
+    'baseball player', 'bridegroom', 'scuba diver', 'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn',
+    'rose hip', 'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra', 'stinkhorn mushroom', 'earth star fungus',
+    'hen of the woods mushroom', 'bolete', 'corn cob', 'toilet paper'
+]
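
A minimal sketch of how these template sets are typically expanded into zero-shot prompts, assuming the repository root is on PYTHONPATH (the snippet is illustrative only):

    from datasets.imagenet_template import imagenet_classes, template_meta

    # Pair every class name with every template in the chosen set.
    templates = template_meta['subset']
    prompts = [t.format(name) for name in imagenet_classes for t in templates]
    # 7 templates x 1000 classes -> 7000 prompt strings such as 'itap of a goldfish.'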

+ 160 - 0
datasets/tokenizer.py

@@ -0,0 +1,160 @@
+# -------------------------------------------------------------------------
+# MIT License
+#
+# Copyright (c) 2021 OpenAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bpe_simple_vocab_16e6.txt.gz')
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """Returns list of utf-8 byte and a corresponding list of unicode strings.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent
+    coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables
+    between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord('!'), ord('~') + 1)) + list(range(ord('¡'), ord('¬') + 1)) + list(range(ord('®'), ord('ÿ') + 1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
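+
+# For example, printable bytes map to themselves (bytes_to_unicode()[ord('a')] == 'a'),
+# while non-printable ones are shifted past 255 (byte 0 maps to chr(256), i.e. 'Ā').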
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
+        merges = merges[1:49152 - 256 - 2 + 1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + '</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+        self.pat = re.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>', )
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + '</w>'
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace').replace('</w>', ' ')
+        return text
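
A quick round-trip with SimpleTokenizer, assuming ftfy and regex are installed and the bundled bpe_simple_vocab_16e6.txt.gz sits next to this file:

    from datasets.tokenizer import SimpleTokenizer

    tokenizer = SimpleTokenizer()
    ids = tokenizer.encode('A photo of a dog!')  # cleaned, lower-cased, then BPE-encoded ids
    text = tokenizer.decode(ids)                 # roughly recovers the cleaned text: 'a photo of a dog ! '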

+ 142 - 0
demo/demo_seg.py

@@ -0,0 +1,142 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import argparse
+import os.path as osp
+import sys
+
+parentdir = osp.dirname(osp.dirname(__file__))
+sys.path.insert(0, parentdir)
+
+import mmcv
+import torch
+from datasets import build_text_transform
+from mmcv.cnn.utils import revert_sync_batchnorm
+from mmcv.image import tensor2imgs
+from mmcv.parallel import collate, scatter
+from models import build_model
+from omegaconf import read_write
+from segmentation.datasets import COCOObjectDataset, PascalContextDataset, PascalVOCDataset
+from segmentation.evaluation import build_seg_demo_pipeline, build_seg_inference
+from utils import get_config, load_checkpoint
+
+def parse_args():
+    parser = argparse.ArgumentParser('GroupViT demo')
+    parser.add_argument(
+        '--cfg',
+        type=str,
+        required=True,
+        help='path to config file',
+    )
+    parser.add_argument(
+        '--opts',
+        help="Modify config options by adding 'KEY VALUE' pairs. ",
+        default=None,
+        nargs='+',
+    )
+
+    parser.add_argument('--resume', help='resume from checkpoint')
+    parser.add_argument(
+        '--vis',
+        help='Visualization mode(s); accepts one or more of "input", "pred", "input_pred", '
+        '"all_groups", "first_group", "final_group", "input_pred_label"',
+        default=None,
+        nargs='+')
+
+    parser.add_argument('--device', default='cuda:0', help='Device used for inference')
+    parser.add_argument(
+        '--dataset', default='voc', choices=['voc', 'coco', 'context'], help='dataset classes for visualization')
+
+    parser.add_argument('--input', type=str, help='input image path')
+    parser.add_argument('--output_dir', type=str, help='output dir')
+
+    args = parser.parse_args()
+    args.local_rank = 0  # compatible with config
+
+    return args
+
+
+def inference(args, cfg):
+    model = build_model(cfg.model)
+    model = revert_sync_batchnorm(model)
+    model.to(args.device)
+    model.eval()
+
+    load_checkpoint(cfg, model, None, None)
+
+    text_transform = build_text_transform(False, cfg.data.text_aug, with_dc=False)
+    if args.dataset == 'voc':
+        dataset_class = PascalVOCDataset
+        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_voc12.py'
+    elif args.dataset == 'coco':
+        dataset_class = COCOObjectDataset
+        seg_cfg = 'segmentation/configs/_base_/datasets/coco.py'
+    elif args.dataset == 'context':
+        dataset_class = PascalContextDataset
+        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_context.py'
+    else:
+        raise ValueError('Unknown dataset: {}'.format(args.dataset))
+
+    with read_write(cfg):
+        cfg.evaluate.seg.cfg = seg_cfg
+        cfg.evaluate.seg.opts = ['test_cfg.mode=whole']
+
+    seg_model = build_seg_inference(model, dataset_class, text_transform, cfg.evaluate.seg)
+
+    vis_seg(seg_model, args.input, args.output_dir, args.vis)
+
+
+def vis_seg(seg_model, input_img, output_dir, vis_modes):
+    device = next(seg_model.parameters()).device
+    test_pipeline = build_seg_demo_pipeline()
+    # prepare data
+    data = dict(img=input_img)
+    data = test_pipeline(data)
+    data = collate([data], samples_per_gpu=1)
+    if next(seg_model.parameters()).is_cuda:
+        # scatter to specified GPU
+        data = scatter(data, [device])[0]
+    else:
+        data['img_metas'] = [i.data[0] for i in data['img_metas']]
+    with torch.no_grad():
+        result = seg_model(return_loss=False, rescale=True, **data)
+
+    img_tensor = data['img'][0]
+    img_metas = data['img_metas'][0]
+    imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+    assert len(imgs) == len(img_metas)
+
+    for img, img_meta in zip(imgs, img_metas):
+        h, w, _ = img_meta['img_shape']
+        img_show = img[:h, :w, :]
+
+        ori_h, ori_w = img_meta['ori_shape'][:-1]
+        img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+        for vis_mode in vis_modes:
+            out_file = osp.join(output_dir, 'vis_imgs', vis_mode, f'{vis_mode}.jpg')
+            seg_model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)
+
+
+def main():
+    args = parse_args()
+    cfg = get_config(args)
+
+    with read_write(cfg):
+        cfg.evaluate.eval_only = True
+
+    inference(args, cfg)
+
+
+if __name__ == '__main__':
+    main()

binary
demo/examples/coco.jpg


binary
demo/examples/ctx.jpg


binary
demo/examples/voc.jpg


binary
figs/github_arch.gif


binary
figs/github_coco.gif


binary
figs/github_ctx.gif


binary
figs/github_voc.gif


+ 460 - 0
main_group_vit.py

@@ -0,0 +1,460 @@
+# -------------------------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE
+#
+# Written by Ze Liu, Zhenda Xie
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import argparse
+import datetime
+import os
+import os.path as osp
+import time
+from collections import defaultdict
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from datasets import build_loader, build_text_transform, imagenet_classes
+from mmcv.parallel import MMDistributedDataParallel
+from mmcv.runner import get_dist_info, init_dist, set_random_seed
+from mmcv.utils import collect_env, get_git_hash
+from mmseg.apis import multi_gpu_test
+from models import build_model
+from omegaconf import OmegaConf, read_write
+from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
+from timm.utils import AverageMeter, accuracy
+from utils import (auto_resume_helper, build_dataset_class_tokens, build_optimizer, build_scheduler, data2cuda,
+                   get_config, get_grad_norm, get_logger, load_checkpoint, parse_losses, reduce_tensor, save_checkpoint)
+
+try:
+    # noinspection PyUnresolvedReferences
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+def parse_args():
+    parser = argparse.ArgumentParser('GroupViT training and evaluation script')
+    parser.add_argument('--cfg', type=str, required=True, help='path to config file')
+    parser.add_argument('--opts', help="Modify config options by adding 'KEY=VALUE' list. ", default=None, nargs='+')
+
+    # easy config modification
+    parser.add_argument('--batch-size', type=int, help='batch size for single GPU')
+    parser.add_argument('--resume', help='resume from checkpoint')
+    parser.add_argument(
+        '--amp-opt-level',
+        type=str,
+        default='O1',
+        choices=['O0', 'O1', 'O2'],
+        help='mixed precision opt level, if O0, no amp is used')
+    parser.add_argument(
+        '--output', type=str, help='root of output folder, '
+        'the full path is <output>/<model_name>/<tag>')
+    parser.add_argument('--tag', type=str, help='tag of experiment')
+    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+    parser.add_argument('--wandb', action='store_true', help='Use W&B to log experiments')
+    parser.add_argument('--keep', type=int, help='Maximum checkpoint to keep')
+
+    # distributed training
+    parser.add_argument('--local_rank', type=int, required=True, help='local rank for DistributedDataParallel')
+
+    args = parser.parse_args()
+
+    return args
+
+
+def train(cfg):
+    if cfg.wandb and dist.get_rank() == 0:
+        import wandb
+        wandb.init(
+            project='group_vit',
+            name=osp.join(cfg.model_name, cfg.tag),
+            dir=cfg.output,
+            config=OmegaConf.to_container(cfg, resolve=True),
+            resume=cfg.checkpoint.auto_resume)
+    else:
+        wandb = None
+    # wait for wandb to finish initializing before other ranks proceed
+    dist.barrier()
+    dataset_train, dataset_val, \
+        data_loader_train, data_loader_val = build_loader(cfg.data)
+    data_loader_seg = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))
+
+    logger = get_logger()
+
+    logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
+    model = build_model(cfg.model)
+    model.cuda()
+    logger.info(str(model))
+
+    optimizer = build_optimizer(cfg.train, model)
+    if cfg.train.amp_opt_level != 'O0':
+        model, optimizer = amp.initialize(model, optimizer, opt_level=cfg.train.amp_opt_level)
+    model = MMDistributedDataParallel(model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
+    model_without_ddp = model.module
+
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f'number of params: {n_parameters}')
+    lr_scheduler = build_scheduler(cfg.train, optimizer, len(data_loader_train))
+
+    if cfg.checkpoint.auto_resume:
+        resume_file = auto_resume_helper(cfg.output)
+        if resume_file:
+            if cfg.checkpoint.resume:
+                logger.warning(f'auto-resume changing resume file from {cfg.checkpoint.resume} to {resume_file}')
+            with read_write(cfg):
+                cfg.checkpoint.resume = resume_file
+            logger.info(f'auto resuming from {resume_file}')
+        else:
+            logger.info(f'no checkpoint found in {cfg.output}, ignoring auto resume')
+
+    max_accuracy = max_miou = 0.0
+    max_metrics = {'max_accuracy': max_accuracy, 'max_miou': max_miou}
+
+    if cfg.checkpoint.resume:
+        max_metrics = load_checkpoint(cfg, model_without_ddp, optimizer, lr_scheduler)
+        max_accuracy, max_miou = max_metrics['max_accuracy'], max_metrics['max_miou']
+        if 'cls' in cfg.evaluate.task:
+            acc1, acc5, loss = validate_cls(cfg, data_loader_val, model)
+            logger.info(f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%')
+        if 'seg' in cfg.evaluate.task:
+            miou = validate_seg(cfg, data_loader_seg, model)
+            logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
+        if cfg.evaluate.eval_only:
+            return
+
+    logger.info('Start training')
+    start_time = time.time()
+    for epoch in range(cfg.train.start_epoch, cfg.train.epochs):
+        loss_train_dict = train_one_epoch(cfg, model, data_loader_train, optimizer, epoch, lr_scheduler)
+        if dist.get_rank() == 0 and (epoch % cfg.checkpoint.save_freq == 0 or epoch == (cfg.train.epochs - 1)):
+            save_checkpoint(cfg, epoch, model_without_ddp, {
+                'max_accuracy': max_accuracy,
+                'max_miou': max_miou
+            }, optimizer, lr_scheduler)
+        dist.barrier()
+        loss_train = loss_train_dict['total_loss']
+        logger.info(f'Avg loss of the network on the {len(dataset_train)} train images: {loss_train:.2f}')
+
+        # evaluate
+        if (epoch % cfg.evaluate.eval_freq == 0 or epoch == (cfg.train.epochs - 1)):
+            if 'cls' in cfg.evaluate.task:
+                acc1, acc5, loss = validate_cls(cfg, data_loader_val, model)
+                logger.info(f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%')
+                max_metrics['max_accuracy'] = max(max_metrics['max_accuracy'], acc1)
+                if cfg.evaluate.cls.save_best and dist.get_rank() == 0 and acc1 > max_accuracy:
+                    save_checkpoint(
+                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_acc1')
+                dist.barrier()
+                max_accuracy = max_metrics['max_accuracy']
+                logger.info(f'Max accuracy: {max_accuracy:.2f}%')
+            if 'seg' in cfg.evaluate.task:
+                miou = validate_seg(cfg, data_loader_seg, model)
+                logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
+                max_metrics['max_miou'] = max(max_metrics['max_miou'], miou)
+                if cfg.evaluate.seg.save_best and dist.get_rank() == 0 and miou > max_miou:
+                    save_checkpoint(
+                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_miou')
+                dist.barrier()
+                max_miou = max_metrics['max_miou']
+                logger.info(f'Max mIoU: {max_miou:.2f}%')
+
+        if wandb is not None:
+            log_stat = {f'epoch/train_{k}': v for k, v in loss_train_dict.items()}
+            log_stat.update({
+                'epoch/val_acc1': acc1,
+                'epoch/val_acc5': acc5,
+                'epoch/val_loss': loss,
+                'epoch/val_miou': miou,
+                'epoch/epoch': epoch,
+                'epoch/n_parameters': n_parameters
+            })
+            wandb.log(log_stat)
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    logger.info('Training time {}'.format(total_time_str))
+    dist.barrier()
+
+
+def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler):
+    logger = get_logger()
+    dist.barrier()
+    model.train()
+    optimizer.zero_grad()
+    if config.wandb and dist.get_rank() == 0:
+        import wandb
+    else:
+        wandb = None
+
+    num_steps = len(data_loader)
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    norm_meter = AverageMeter()
+    log_vars_meters = defaultdict(AverageMeter)
+
+    start = time.time()
+    end = time.time()
+    for idx, samples in enumerate(data_loader):
+
+        batch_size = config.data.batch_size
+
+        losses = model(**samples)
+
+        loss, log_vars = parse_losses(losses)
+
+        if config.train.accumulation_steps > 1:
+            loss = loss / config.train.accumulation_steps
+            if config.train.amp_opt_level != 'O0':
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+                if config.train.clip_grad:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
+                else:
+                    grad_norm = get_grad_norm(amp.master_params(optimizer))
+            else:
+                loss.backward()
+                if config.train.clip_grad:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
+                else:
+                    grad_norm = get_grad_norm(model.parameters())
+            if (idx + 1) % config.train.accumulation_steps == 0:
+                optimizer.step()
+                optimizer.zero_grad()
+                lr_scheduler.step_update(epoch * num_steps + idx)
+        else:
+            optimizer.zero_grad()
+            if config.train.amp_opt_level != 'O0':
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+                if config.train.clip_grad:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
+                else:
+                    grad_norm = get_grad_norm(amp.master_params(optimizer))
+            else:
+                loss.backward()
+                if config.train.clip_grad:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
+                else:
+                    grad_norm = get_grad_norm(model.parameters())
+            optimizer.step()
+            lr_scheduler.step_update(epoch * num_steps + idx)
+
+        torch.cuda.synchronize()
+
+        loss_meter.update(loss.item(), batch_size)
+        for loss_name in log_vars:
+            log_vars_meters[loss_name].update(log_vars[loss_name], batch_size)
+        norm_meter.update(grad_norm)
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.print_freq == 0:
+            lr = optimizer.param_groups[0]['lr']
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            etas = batch_time.avg * (num_steps - idx)
+            log_vars_str = '\t'.join(f'{n} {m.val:.4f} ({m.avg:.4f})' for n, m in log_vars_meters.items())
+            logger.info(f'Train: [{epoch}/{config.train.epochs}][{idx}/{num_steps}]\t'
+                        f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
+                        f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
+                        f'total_loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                        f'{log_vars_str}\t'
+                        f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
+                        f'mem {memory_used:.0f}MB')
+            if wandb is not None:
+                log_stat = {f'iter/train_{n}': m.avg for n, m in log_vars_meters.items()}
+                log_stat['iter/train_total_loss'] = loss_meter.avg
+                log_stat['iter/learning_rate'] = lr
+                wandb.log(log_stat)
+
+    epoch_time = time.time() - start
+    logger.info(f'EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}')
+    result_dict = dict(total_loss=loss_meter.avg)
+    for n, m in log_vars_meters.items():
+        result_dict[n] = m.avg
+    dist.barrier()
+    return result_dict
+
+
+@torch.no_grad()
+def validate_cls(config, data_loader, model):
+    logger = get_logger()
+    dist.barrier()
+    criterion = torch.nn.CrossEntropyLoss()
+    model.eval()
+
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    acc1_meter = AverageMeter()
+    acc5_meter = AverageMeter()
+
+    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
+
+    end = time.time()
+    logger.info('Building zero shot classifier')
+    text_embedding = data2cuda(
+        model.module.build_text_embedding(
+            build_dataset_class_tokens(text_transform, config.evaluate.cls.template, imagenet_classes)))
+    logger.info('Zero shot classifier built')
+    for idx, samples in enumerate(data_loader):
+        target = samples.pop('target').data[0].cuda()
+        target = data2cuda(target)
+
+        # compute output
+        output = model(**samples, text=text_embedding)
+
+        # measure accuracy and record loss
+        loss = criterion(output, target)
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        acc1 = reduce_tensor(acc1)
+        acc5 = reduce_tensor(acc5)
+        loss = reduce_tensor(loss)
+
+        loss_meter.update(loss.item(), target.size(0))
+        acc1_meter.update(acc1.item(), target.size(0))
+        acc5_meter.update(acc5.item(), target.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.print_freq == 0:
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
+                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
+                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
+                        f'Mem {memory_used:.0f}MB')
+    logger.info('Clearing zero shot classifier')
+    torch.cuda.empty_cache()
+    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
+    dist.barrier()
+    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
+
+
+@torch.no_grad()
+def validate_seg(config, data_loader, model):
+    logger = get_logger()
+    dist.barrier()
+    model.eval()
+
+    if hasattr(model, 'module'):
+        model_without_ddp = model.module
+    else:
+        model_without_ddp = model
+
+    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
+    seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)
+
+    mmddp_model = MMDistributedDataParallel(
+        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
+    mmddp_model.eval()
+    results = multi_gpu_test(
+        model=mmddp_model,
+        data_loader=data_loader,
+        tmpdir=None,
+        gpu_collect=True,
+        efficient_test=False,
+        pre_eval=True,
+        format_only=False)
+
+    if dist.get_rank() == 0:
+        metric = [data_loader.dataset.evaluate(results, metric='mIoU')]
+    else:
+        metric = [None]
+    dist.broadcast_object_list(metric)
+    miou_result = metric[0]['mIoU'] * 100
+
+    torch.cuda.empty_cache()
+    logger.info(f'Eval Seg mIoU {miou_result:.2f}')
+    dist.barrier()
+    return miou_result
+
+
+def main():
+    args = parse_args()
+    cfg = get_config(args)
+
+    if cfg.train.amp_opt_level != 'O0':
+        assert amp is not None, 'amp not installed!'
+
+    # start faster ref: https://github.com/open-mmlab/mmdetection/pull/7036
+    mp.set_start_method('fork', force=True)
+    init_dist('pytorch')
+    rank, world_size = get_dist_info()
+    print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')
+
+    dist.barrier()
+
+    set_random_seed(cfg.seed, use_rank_shift=True)
+    cudnn.benchmark = True
+
+    os.makedirs(cfg.output, exist_ok=True)
+    logger = get_logger(cfg)
+
+    # linear scale the learning rate according to total batch size, may not be optimal
+    linear_scaled_lr = cfg.train.base_lr * cfg.data.batch_size * world_size / 4096.0
+    linear_scaled_warmup_lr = cfg.train.warmup_lr * cfg.data.batch_size * world_size / 4096.0
+    linear_scaled_min_lr = cfg.train.min_lr * cfg.data.batch_size * world_size / 4096.0
+
+    # gradient accumulation also need to scale the learning rate
+    if cfg.train.accumulation_steps > 1:
+        linear_scaled_lr = linear_scaled_lr * cfg.train.accumulation_steps
+        linear_scaled_warmup_lr = linear_scaled_warmup_lr * cfg.train.accumulation_steps
+        linear_scaled_min_lr = linear_scaled_min_lr * cfg.train.accumulation_steps
+
+    with read_write(cfg):
+        logger.info(f'Scale base_lr from {cfg.train.base_lr} to {linear_scaled_lr}')
+        logger.info(f'Scale warmup_lr from {cfg.train.warmup_lr} to {linear_scaled_warmup_lr}')
+        logger.info(f'Scale min_lr from {cfg.train.min_lr} to {linear_scaled_min_lr}')
+        cfg.train.base_lr = linear_scaled_lr
+        cfg.train.warmup_lr = linear_scaled_warmup_lr
+        cfg.train.min_lr = linear_scaled_min_lr
+
+    if dist.get_rank() == 0:
+        path = os.path.join(cfg.output, 'config.json')
+        OmegaConf.save(cfg, path)
+        logger.info(f'Full config saved to {path}')
+
+    # log env info
+    env_info_dict = collect_env()
+    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
+    dash_line = '-' * 60 + '\n'
+    logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
+
+    logger.info(f'Git hash: {get_git_hash(digits=7)}')
+
+    # print config
+    logger.info(OmegaConf.to_yaml(cfg))
+
+    train(cfg)
+    dist.barrier()
+
+
+if __name__ == '__main__':
+    main()
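
The block in main() above applies the standard linear learning-rate scaling rule against a reference total batch size of 4096; a small worked example with made-up values:

    base_lr, warmup_lr, min_lr = 1.6e-3, 4e-6, 4e-5  # hypothetical config values
    batch_size, world_size = 256, 8                  # hypothetical per-GPU batch size and GPU count

    scale = batch_size * world_size / 4096.0         # 0.5 for these numbers
    linear_scaled_lr = base_lr * scale               # 8.0e-04
    linear_scaled_warmup_lr = warmup_lr * scale      # 2.0e-06
    linear_scaled_min_lr = min_lr * scale            # 2.0e-05
    # With accumulation_steps > 1, each value is additionally multiplied by accumulation_steps.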

+ 194 - 0
main_seg.py

@@ -0,0 +1,194 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# ------------------------------------------------------------------------------
+
+import argparse
+import os
+import os.path as osp
+
+import mmcv
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from datasets import build_text_transform
+from main_group_vit import validate_seg
+from mmcv.image import tensor2imgs
+from mmcv.parallel import MMDistributedDataParallel
+from mmcv.runner import set_random_seed
+from models import build_model
+from omegaconf import OmegaConf, read_write
+from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
+from utils import get_config, get_logger, load_checkpoint
+
+try:
+    # noinspection PyUnresolvedReferences
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+def parse_args():
+    parser = argparse.ArgumentParser('GroupViT segmentation evaluation and visualization')
+    parser.add_argument(
+        '--cfg',
+        type=str,
+        required=True,
+        help='path to config file',
+    )
+    parser.add_argument(
+        '--opts',
+        help="Modify config options by adding 'KEY VALUE' pairs. ",
+        default=None,
+        nargs='+',
+    )
+
+    parser.add_argument('--resume', help='resume from checkpoint')
+    parser.add_argument(
+        '--output', type=str, help='root of output folder, '
+        'the full path is <output>/<model_name>/<tag>')
+    parser.add_argument('--tag', help='tag of experiment')
+    parser.add_argument(
+        '--vis',
+        help='Visualization mode(s); accepts one or more of input, pred, input_seg, '
+        'input_pred_seg_label, all_groups, first_group, last_group',
+        default=None,
+        nargs='+')
+
+    # distributed training
+    parser.add_argument('--local_rank', type=int, required=True, help='local rank for DistributedDataParallel')
+
+    args = parser.parse_args()
+
+    return args
+
+
+def inference(cfg):
+    logger = get_logger()
+    data_loader = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))
+    dataset = data_loader.dataset
+
+    logger.info(f'Evaluating dataset: {dataset}')
+
+    logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
+    model = build_model(cfg.model)
+    model.cuda()
+    logger.info(str(model))
+
+    if cfg.train.amp_opt_level != 'O0':
+        model = amp.initialize(model, None, opt_level=cfg.train.amp_opt_level)
+
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f'number of params: {n_parameters}')
+
+    load_checkpoint(cfg, model, None, None)
+
+    if 'seg' in cfg.evaluate.task:
+        miou = validate_seg(cfg, data_loader, model)
+        logger.info(f'mIoU of the network on the {len(data_loader.dataset)} test images: {miou:.2f}%')
+    else:
+        logger.info('No segmentation evaluation specified')
+
+    if cfg.vis:
+        vis_seg(cfg, data_loader, model, cfg.vis)
+
+
+@torch.no_grad()
+def vis_seg(config, data_loader, model, vis_modes):
+    dist.barrier()
+    model.eval()
+
+    if hasattr(model, 'module'):
+        model_without_ddp = model.module
+    else:
+        model_without_ddp = model
+
+    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
+    seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)
+
+    mmddp_model = MMDistributedDataParallel(
+        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
+    mmddp_model.eval()
+    model = mmddp_model.module
+    device = next(model.parameters()).device
+    dataset = data_loader.dataset
+
+    if dist.get_rank() == 0:
+        prog_bar = mmcv.ProgressBar(len(dataset))
+    loader_indices = data_loader.batch_sampler
+    for batch_indices, data in zip(loader_indices, data_loader):
+        with torch.no_grad():
+            result = mmddp_model(return_loss=False, **data)
+
+        img_tensor = data['img'][0]
+        img_metas = data['img_metas'][0].data[0]
+        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+        assert len(imgs) == len(img_metas)
+
+        for batch_idx, img, img_meta in zip(batch_indices, imgs, img_metas):
+            h, w, _ = img_meta['img_shape']
+            img_show = img[:h, :w, :]
+
+            ori_h, ori_w = img_meta['ori_shape'][:-1]
+            img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+            for vis_mode in vis_modes:
+                out_file = osp.join(config.output, 'vis_imgs', vis_mode, f'{batch_idx:04d}.jpg')
+                model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)
+            if dist.get_rank() == 0:
+                batch_size = len(result) * dist.get_world_size()
+                for _ in range(batch_size):
+                    prog_bar.update()
+
+
+def main():
+    args = parse_args()
+    cfg = get_config(args)
+
+    if cfg.train.amp_opt_level != 'O0':
+        assert amp is not None, 'amp not installed!'
+
+    with read_write(cfg):
+        cfg.evaluate.eval_only = True
+
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        rank = int(os.environ['RANK'])
+        world_size = int(os.environ['WORLD_SIZE'])
+        print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')
+    else:
+        rank = -1
+        world_size = -1
+    torch.cuda.set_device(cfg.local_rank)
+
+    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
+
+    dist.barrier()
+
+    set_random_seed(cfg.seed, use_rank_shift=True)
+    cudnn.benchmark = True
+
+    os.makedirs(cfg.output, exist_ok=True)
+    logger = get_logger(cfg)
+
+    if dist.get_rank() == 0:
+        path = os.path.join(cfg.output, 'config.json')
+        OmegaConf.save(cfg, path)
+        logger.info(f'Full config saved to {path}')
+
+    # print config
+    logger.info(OmegaConf.to_yaml(cfg))
+
+    inference(cfg)
+    dist.barrier()
+
+
+if __name__ == '__main__':
+    main()
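
main() here relies on the standard torch.distributed launcher environment, since init_process_group uses the env:// method; a sketch of the variables a launcher such as torchrun provides per process (values illustrative):

    import os

    os.environ.setdefault('RANK', '0')              # global rank of this process
    os.environ.setdefault('WORLD_SIZE', '1')        # total number of processes
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')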

+ 19 - 0
models/__init__.py

@@ -0,0 +1,19 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from .builder import build_model
+from .group_vit import GroupViT
+from .multi_label_contrastive import MultiLabelContrastive
+from .transformer import TextTransformer
+
+__all__ = ['build_model', 'MultiLabelContrastive', 'GroupViT', 'TextTransformer']

+ 24 - 0
models/builder.py

@@ -0,0 +1,24 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from mmcv.utils import Registry
+from omegaconf import OmegaConf
+
+MODELS = Registry('model')
+
+
+def build_model(config):
+
+    model = MODELS.build(OmegaConf.to_container(config, resolve=True))
+
+    return model
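
build_model resolves the OmegaConf node to a plain dict and dispatches on its 'type' key through the mmcv registry; a minimal sketch of the mechanism, using a hypothetical TinyModel rather than the real GroupViT classes:

    from mmcv.utils import Registry
    from omegaconf import OmegaConf

    MODELS = Registry('model')

    @MODELS.register_module()
    class TinyModel:  # hypothetical stand-in, for illustration only
        def __init__(self, dim=64):
            self.dim = dim

    cfg = OmegaConf.create({'type': 'TinyModel', 'dim': 128})
    model = MODELS.build(OmegaConf.to_container(cfg, resolve=True))
    assert isinstance(model, TinyModel) and model.dim == 128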

+ 882 - 0
models/group_vit.py

@@ -0,0 +1,882 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from einops import rearrange
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from .builder import MODELS
+from .misc import Result, interpolate_pos_encoding
+
+
+class Mlp(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class MixerMlp(Mlp):
+
+    def forward(self, x):
+        return super().forward(x.transpose(1, 2)).transpose(1, 2)
+
+
+def hard_softmax(logits, dim):
+    y_soft = logits.softmax(dim)
+    # Straight through.
+    index = y_soft.max(dim, keepdim=True)[1]
+    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+    ret = y_hard - y_soft.detach() + y_soft
+
+    return ret
+
+
+def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
+    # _gumbels = (-torch.empty_like(
+    #     logits,
+    #     memory_format=torch.legacy_contiguous_format).exponential_().log()
+    #             )  # ~Gumbel(0,1)
+    # more stable https://github.com/pytorch/pytorch/issues/41663
+    gumbel_dist = torch.distributions.gumbel.Gumbel(
+        torch.tensor(0., device=logits.device, dtype=logits.dtype),
+        torch.tensor(1., device=logits.device, dtype=logits.dtype))
+    gumbels = gumbel_dist.sample(logits.shape)
+
+    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
+    y_soft = gumbels.softmax(dim)
+
+    if hard:
+        # Straight through.
+        index = y_soft.max(dim, keepdim=True)[1]
+        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+        ret = y_hard - y_soft.detach() + y_soft
+    else:
+        # Reparametrization trick.
+        ret = y_soft
+    return ret
+
+
+class AssignAttention(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads=1,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 hard=True,
+                 gumbel=False,
+                 gumbel_tau=1.,
+                 sum_assign=False,
+                 assign_eps=1.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
+        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
+        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.hard = hard
+        self.gumbel = gumbel
+        self.gumbel_tau = gumbel_tau
+        self.sum_assign = sum_assign
+        self.assign_eps = assign_eps
+
+    def get_attn(self, attn, gumbel=None, hard=None):
+
+        if gumbel is None:
+            gumbel = self.gumbel
+
+        if hard is None:
+            hard = self.hard
+
+        attn_dim = -2
+        if gumbel and self.training:
+            attn = gumbel_softmax(attn, dim=attn_dim, hard=hard, tau=self.gumbel_tau)
+        else:
+            if hard:
+                attn = hard_softmax(attn, dim=attn_dim)
+            else:
+                attn = F.softmax(attn, dim=attn_dim)
+
+        return attn
+
+    def forward(self, query, key=None, *, value=None, return_attn=False):
+        B, N, C = query.shape
+        if key is None:
+            key = query
+        if value is None:
+            value = key
+        S = key.size(1)
+        # [B, nh, N, C//nh]
+        q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
+        # [B, nh, S, C//nh]
+        k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
+        # [B, nh, S, C//nh]
+        v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
+
+        # [B, nh, N, S]
+        raw_attn = (q @ k.transpose(-2, -1)) * self.scale
+
+        attn = self.get_attn(raw_attn)
+        if return_attn:
+            hard_attn = attn.clone()
+            soft_attn = self.get_attn(raw_attn, gumbel=False, hard=False)
+            attn_dict = {'hard': hard_attn, 'soft': soft_attn}
+        else:
+            attn_dict = None
+
+        if not self.sum_assign:
+            attn = attn / (attn.sum(dim=-1, keepdim=True) + self.assign_eps)
+        attn = self.attn_drop(attn)
+        assert attn.shape == (B, self.num_heads, N, S)
+
+        # [B, nh, N, C//nh] <- [B, nh, N, S] @ [B, nh, S, C//nh]
+        out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
+
+        out = self.proj(out)
+        out = self.proj_drop(out)
+        return out, attn_dict
+
+    def extra_repr(self):
+        return f'num_heads={self.num_heads}, \n' \
+               f'hard={self.hard}, \n' \
+               f'gumbel={self.gumbel}, \n' \
+               f'sum_assign={self.sum_assign}, \n' \
+               f'gumbel_tau={self.gumbel_tau}, \n' \
+               f'assign_eps={self.assign_eps}'
+
+
+class GroupingBlock(nn.Module):
+    """Grouping Block to group similar segments together.
+
+    Args:
+        dim (int): Dimension of the input.
+        out_dim (int): Dimension of the output.
+        num_heads (int): Number of heads in the grouping attention.
+        num_group_token (int): Number of input group tokens.
+        num_output_group (int): Number of output groups.
+        norm_layer (nn.Module): Normalization layer to use.
+        mlp_ratio (float | tuple[float]): Ratios of the token-mixing and channel MLP
+            hidden dims to the embedding dim. Default: (0.5, 4.0)
+        hard (bool): Whether to use hard or soft assignment. Default: True
+        gumbel (bool): Whether to use gumbel softmax. Default: True
+        sum_assign (bool): Whether to sum assignment or average. Default: False
+        assign_eps (float): Epsilon to avoid division by zero. Default: 1
+        gumbel_tau (float): Temperature for gumbel softmax. Default: 1
+    """
+
+    def __init__(self,
+                 *,
+                 dim,
+                 out_dim,
+                 num_heads,
+                 num_group_token,
+                 num_output_group,
+                 norm_layer,
+                 mlp_ratio=(0.5, 4.0),
+                 hard=True,
+                 gumbel=True,
+                 sum_assign=False,
+                 assign_eps=1.,
+                 gumbel_tau=1.):
+        super(GroupingBlock, self).__init__()
+        self.dim = dim
+        self.hard = hard
+        self.gumbel = gumbel
+        self.sum_assign = sum_assign
+        self.num_output_group = num_output_group
+        # norm on group_tokens
+        self.norm_tokens = norm_layer(dim)
+        tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
+        self.mlp_inter = Mlp(num_group_token, tokens_dim, num_output_group)
+        self.norm_post_tokens = norm_layer(dim)
+        # norm on x
+        self.norm_x = norm_layer(dim)
+        self.pre_assign_attn = CrossAttnBlock(
+            dim=dim, num_heads=num_heads, mlp_ratio=4, qkv_bias=True, norm_layer=norm_layer, post_norm=True)
+
+        self.assign = AssignAttention(
+            dim=dim,
+            num_heads=1,
+            qkv_bias=True,
+            hard=hard,
+            gumbel=gumbel,
+            gumbel_tau=gumbel_tau,
+            sum_assign=sum_assign,
+            assign_eps=assign_eps)
+        self.norm_new_x = norm_layer(dim)
+        self.mlp_channels = Mlp(dim, channels_dim, out_dim)
+        if out_dim is not None and dim != out_dim:
+            self.reduction = nn.Sequential(norm_layer(dim), nn.Linear(dim, out_dim, bias=False))
+        else:
+            self.reduction = nn.Identity()
+
+    def extra_repr(self):
+        return f'hard={self.hard}, \n' \
+               f'gumbel={self.gumbel}, \n' \
+               f'sum_assign={self.sum_assign}, \n' \
+               f'num_output_group={self.num_output_group}'
+
+    def project_group_token(self, group_tokens):
+        """
+        Args:
+            group_tokens (torch.Tensor): group tokens, [B, S_1, C]
+
+        inter_weight (torch.Tensor): [B, S_2, S_1], S_2 is the new number of
+            group tokens, it's already softmaxed along dim=-1
+
+        Returns:
+            projected_group_tokens (torch.Tensor): [B, S_2, C]
+        """
+        # [B, S_2, C] <- [B, S_1, C]
+        projected_group_tokens = self.mlp_inter(group_tokens.transpose(1, 2)).transpose(1, 2)
+        projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
+        return projected_group_tokens
+
+    def forward(self, x, group_tokens, return_attn=False):
+        """
+        Args:
+            x (torch.Tensor): image tokens, [B, L, C]
+            group_tokens (torch.Tensor): group tokens, [B, S_1, C]
+            return_attn (bool): whether to return attention map
+
+        Returns:
+            new_x (torch.Tensor): [B, S_2, C], S_2 is the new number of
+                group tokens
+            attn_dict (dict | None): attention maps if ``return_attn`` is True, else None
+        """
+        group_tokens = self.norm_tokens(group_tokens)
+        x = self.norm_x(x)
+        # [B, S_2, C]
+        projected_group_tokens = self.project_group_token(group_tokens)
+        projected_group_tokens = self.pre_assign_attn(projected_group_tokens, x)
+        new_x, attn_dict = self.assign(projected_group_tokens, x, return_attn=return_attn)
+        new_x += projected_group_tokens
+
+        new_x = self.reduction(new_x) + self.mlp_channels(self.norm_new_x(new_x))
+
+        return new_x, attn_dict
+
+
+class Attention(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 out_dim=None,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 qkv_fuse=False):
+        super().__init__()
+        if out_dim is None:
+            out_dim = dim
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+        self.qkv_fuse = qkv_fuse
+
+        if qkv_fuse:
+            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        else:
+            self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
+            self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
+            self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, out_dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def extra_repr(self):
+        return f'num_heads={self.num_heads}, \n' \
+               f'qk_scale={self.scale}, \n' \
+               f'qkv_fuse={self.qkv_fuse}'
+
+    def forward(self, query, key=None, *, value=None, mask=None):
+        if self.qkv_fuse:
+            assert key is None
+            assert value is None
+            x = query
+            B, N, C = x.shape
+            S = N
+            # [3, B, nh, N, C//nh]
+            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+            # [B, nh, N, C//nh]
+            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        else:
+            B, N, C = query.shape
+            if key is None:
+                key = query
+            if value is None:
+                value = key
+            S = key.size(1)
+            # [B, nh, N, C//nh]
+            q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
+            # [B, nh, S, C//nh]
+            k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
+            # [B, nh, S, C//nh]
+            v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
+
+        # [B, nh, N, S]
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        if mask is not None:
+            attn = attn + mask.unsqueeze(dim=1)
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        assert attn.shape == (B, self.num_heads, N, S)
+
+        # [B, nh, N, C//nh] -> [B, N, C]
+        # out = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
+        out = self.proj(out)
+        out = self.proj_drop(out)
+        return out
+
+
+class CrossAttnBlock(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm,
+                 post_norm=False):
+        super().__init__()
+        if post_norm:
+            self.norm_post = norm_layer(dim)
+            self.norm_q = nn.Identity()
+            self.norm_k = nn.Identity()
+        else:
+            self.norm_q = norm_layer(dim)
+            self.norm_k = norm_layer(dim)
+            self.norm_post = nn.Identity()
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, query, key, *, mask=None):
+        x = query
+        x = x + self.drop_path(self.attn(self.norm_q(query), self.norm_k(key), mask=mask))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = self.norm_post(x)
+        return x
+
+
+class AttnBlock(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            qkv_fuse=True)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x, mask=None):
+        x = x + self.drop_path(self.attn(self.norm1(x), mask=mask))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class GroupingLayer(nn.Module):
+    """A Transformer layer with Grouping Block for one stage.
+
+    Args:
+        dim (int): Number of input channels.
+        num_input_token (int): Number of input tokens.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        num_group_token (int): Number of group tokens appended to the image tokens.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer.
+            In GroupViT setting, Grouping Block serves as the downsampling layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+        group_projector (nn.Module | None, optional): Projector for the grouping layer. Default: None.
+        zero_init_group_token (bool): Whether to initialize the grouping token to 0. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 num_input_token,
+                 depth,
+                 num_heads,
+                 num_group_token,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None,
+                 use_checkpoint=False,
+                 group_projector=None,
+                 zero_init_group_token=False):
+
+        super().__init__()
+        self.dim = dim
+        self.input_length = num_input_token
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+        self.num_group_token = num_group_token
+        if num_group_token > 0:
+            self.group_token = nn.Parameter(torch.zeros(1, num_group_token, dim))
+            if not zero_init_group_token:
+                trunc_normal_(self.group_token, std=.02)
+        else:
+            self.group_token = None
+
+        # build blocks
+        self.depth = depth
+        blocks = []
+        for i in range(depth):
+            blocks.append(
+                AttnBlock(
+                    dim=dim,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    qk_scale=qk_scale,
+                    drop=drop,
+                    attn_drop=attn_drop,
+                    drop_path=drop_path[i],
+                    norm_layer=norm_layer))
+        self.blocks = nn.ModuleList(blocks)
+
+        self.downsample = downsample
+        self.input_resolution = num_input_token
+        self.use_checkpoint = use_checkpoint
+
+        self.group_projector = group_projector
+
+    @property
+    def with_group_token(self):
+        return self.group_token is not None
+
+    def extra_repr(self):
+        return f'dim={self.dim}, \n' \
+               f'input_resolution={self.input_resolution}, \n' \
+               f'depth={self.depth}, \n' \
+               f'num_group_token={self.num_group_token}, \n'
+
+    def split_x(self, x):
+        if self.with_group_token:
+            return x[:, :-self.num_group_token], x[:, -self.num_group_token:]
+        else:
+            return x, None
+
+    def concat_x(self, x, group_token=None):
+        if group_token is None:
+            return x
+        return torch.cat([x, group_token], dim=1)
+
+    def forward(self, x, prev_group_token=None, return_attn=False):
+        """
+        Args:
+            x (torch.Tensor): image tokens, [B, L, C]
+            prev_group_token (torch.Tensor): group tokens, [B, S_1, C]
+            return_attn (bool): whether to return attention maps
+        """
+        if self.with_group_token:
+            group_token = self.group_token.expand(x.size(0), -1, -1)
+            if self.group_projector is not None:
+                group_token = group_token + self.group_projector(prev_group_token)
+        else:
+            group_token = None
+
+        B, L, C = x.shape
+        cat_x = self.concat_x(x, group_token)
+        for blk_idx, blk in enumerate(self.blocks):
+            if self.use_checkpoint:
+                cat_x = checkpoint.checkpoint(blk, cat_x)
+            else:
+                cat_x = blk(cat_x)
+
+        x, group_token = self.split_x(cat_x)
+
+        attn_dict = None
+        if self.downsample is not None:
+            x, attn_dict = self.downsample(x, group_token, return_attn=return_attn)
+
+        return x, group_token, attn_dict
+
+
+class PatchEmbed(nn.Module):
+    """Image to Patch Embedding."""
+
+    def __init__(self, img_size=224, kernel_size=7, stride=4, padding=2, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        kernel_size = to_2tuple(kernel_size)
+        stride = to_2tuple(stride)
+        padding = to_2tuple(padding)
+        self.img_size = img_size
+        self.patches_resolution = (
+            int((img_size[1] + 2 * padding[1] - kernel_size[1]) / stride[1] + 1),
+            int((img_size[0] + 2 * padding[0] - kernel_size[0]) / stride[0] + 1),
+        )
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    @property
+    def num_patches(self):
+        return self.patches_resolution[1] * self.patches_resolution[0]
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        if self.training:
+            # FIXME look at relaxing size constraints
+            assert H == self.img_size[0] and W == self.img_size[1], \
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        hw_shape = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
+        if self.norm is not None:
+            x = self.norm(x)
+        return x, hw_shape
+
+
+@MODELS.register_module()
+class GroupViT(nn.Module):
+    r""" Group Vision Transformer
+        A PyTorch impl of : `GroupViT: Semantic Segmentation Emerges from Text Supervision`  -
+          https://arxiv.org/pdf/2202.11094.pdf
+
+    Args:
+        img_size (int | tuple[int]): Input image size. Default: 224
+        patch_size (int | tuple[int]): Patch size. Default: 16
+        in_chans (int): Number of input image channels. Default: 3
+        num_classes (int): Number of classes for classification head. Default: 0
+        embed_dim (int): Patch embedding dimension. Default: 384
+        embed_factors (list[int]): Embedding dim multipliers for each stage.
+        depths (list[int]): Depth of each stage
+        num_heads (list[int]): Number of heads for each stage
+        num_group_tokens (list[int]): Number of group tokens for each stage
+        num_output_groups (list[int]): Number of output groups after each grouping stage (one fewer than the number of stages)
+        hard_assignment (bool): Whether to use hard assignment or not. Default: True
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+        drop_rate (float): Dropout rate. Default: 0
+        attn_drop_rate (float): Attention dropout rate. Default: 0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.1
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+        pos_embed_type (str): Type of positional embedding. Default: 'simple'
+        freeze_patch_embed (bool): Whether to freeze patch embedding. Default: False
+    """
+
+    def __init__(self,
+                 img_size=224,
+                 patch_size=16,
+                 in_chans=3,
+                 num_classes=0,
+                 embed_dim=384,
+                 embed_factors=[1, 1, 1],
+                 depths=[6, 3, 3],
+                 num_heads=[6, 6, 6],
+                 num_group_tokens=[64, 8, 0],
+                 num_output_groups=[64, 8],
+                 hard_assignment=True,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.1,
+                 patch_norm=True,
+                 use_checkpoint=False,
+                 pos_embed_type='simple',
+                 freeze_patch_embed=False):
+        super().__init__()
+        assert patch_size in [4, 8, 16]
+        self.num_classes = num_classes
+        assert len(embed_factors) == len(depths) == len(num_group_tokens)
+        assert all(_ == 0 for _ in num_heads) or len(depths) == len(num_heads)
+        assert len(depths) - 1 == len(num_output_groups)
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.patch_norm = patch_norm
+        self.num_features = int(embed_dim * embed_factors[len(depths) - 1])
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.qk_scale = qk_scale
+        self.drop_rate = drop_rate
+        self.attn_drop_rate = attn_drop_rate
+        self.drop_path_rate = drop_path_rate
+        self.num_group_tokens = num_group_tokens
+        self.num_output_groups = num_output_groups
+        self.pos_embed_type = pos_embed_type
+        assert pos_embed_type in ['simple', 'fourier']
+
+        norm_layer = nn.LayerNorm
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            img_size=img_size,
+            kernel_size=patch_size,
+            stride=patch_size,
+            padding=0,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+        num_patches = self.patch_embed.num_patches
+        patches_resolution = self.patch_embed.patches_resolution
+        self.patches_resolution = patches_resolution
+
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
+
+        if pos_embed_type == 'simple':
+            self.pos_embed = self.build_simple_position_embedding()
+        elif pos_embed_type == 'fourier':
+            self.pos_embed = self.build_2d_sincos_position_embedding()
+        else:
+            raise ValueError
+
+        if freeze_patch_embed:
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+            self.pos_embed.requires_grad = False
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+
+        num_input_token = num_patches
+        num_output_token = num_input_token
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+
+            dim = int(embed_dim * embed_factors[i_layer])
+            downsample = None
+            if i_layer < self.num_layers - 1:
+                out_dim = embed_dim * embed_factors[i_layer + 1]
+                downsample = GroupingBlock(
+                    dim=dim,
+                    out_dim=out_dim,
+                    num_heads=num_heads[i_layer],
+                    num_group_token=num_group_tokens[i_layer],
+                    num_output_group=num_output_groups[i_layer],
+                    norm_layer=norm_layer,
+                    hard=hard_assignment,
+                    gumbel=hard_assignment)
+                num_output_token = num_output_groups[i_layer]
+
+            if i_layer > 0 and num_group_tokens[i_layer] > 0:
+                prev_dim = int(embed_dim * embed_factors[i_layer - 1])
+                group_projector = nn.Sequential(
+                    norm_layer(prev_dim),
+                    MixerMlp(num_group_tokens[i_layer - 1], prev_dim // 2, num_group_tokens[i_layer]))
+
+                if dim != prev_dim:
+                    group_projector = nn.Sequential(group_projector, norm_layer(prev_dim),
+                                                    nn.Linear(prev_dim, dim, bias=False))
+            else:
+                group_projector = None
+            layer = GroupingLayer(
+                dim=dim,
+                num_input_token=num_input_token,
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                num_group_token=num_group_tokens[i_layer],
+                mlp_ratio=self.mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=downsample,
+                use_checkpoint=use_checkpoint,
+                group_projector=group_projector,
+                # only zero init group token if we have a projection
+                zero_init_group_token=group_projector is not None)
+            self.layers.append(layer)
+            if i_layer < self.num_layers - 1:
+                num_input_token = num_output_token
+
+        self.norm = norm_layer(self.num_features)
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+        self.apply(self._init_weights)
+
+    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
+        if self.pos_embed_type == 'simple' and 'pos_embed' in state_dict:
+            load_pos_embed = state_dict['pos_embed']
+            pos_embed = self.pos_embed
+            if load_pos_embed.shape != pos_embed.shape:
+                H_new = int(self.patch_embed.num_patches**0.5)
+                W_new = H_new
+                H_ori = int(load_pos_embed.shape[1]**0.5)
+                W_ori = H_ori
+                load_pos_embed = F.interpolate(
+                    rearrange(load_pos_embed, 'b (h w) c -> b c h w', h=H_ori, w=W_ori, b=1),
+                    size=(H_new, W_new),
+                    mode='bicubic',
+                    align_corners=False)
+                load_pos_embed = rearrange(load_pos_embed, 'b c h w -> b (h w) c', h=H_new, w=W_new)
+                state_dict['pos_embed'] = load_pos_embed
+        return super().load_state_dict(state_dict, strict)
+
+    def build_simple_position_embedding(self):
+        pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches, self.embed_dim))
+        trunc_normal_(pos_embed, std=.02)
+        return pos_embed
+
+    def build_2d_sincos_position_embedding(self, temperature=10000.):
+        h, w = self.patch_embed.patches_resolution
+        grid_w = torch.arange(w, dtype=torch.float32)
+        grid_h = torch.arange(h, dtype=torch.float32)
+        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
+        assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+        pos_dim = self.embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1. / (temperature**omega)
+        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
+        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
+        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
+
+        pos_embed = nn.Parameter(pos_emb)
+        pos_embed.requires_grad = False
+        return pos_embed
+
+    @property
+    def width(self):
+        return self.num_features
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_pos_embed(self, B, H, W):
+        if self.training:
+            return self.pos_embed
+        pos_embed = self.pos_embed
+        pos_embed = interpolate_pos_encoding(pos_embed, H, W)
+        return pos_embed
+
+    def forward_features(self, x, *, return_attn=False):
+        B = x.shape[0]
+        x, hw_shape = self.patch_embed(x)
+
+        x = x + self.get_pos_embed(B, *hw_shape)
+        x = self.pos_drop(x)
+
+        group_token = None
+        attn_dict_list = []
+        for layer in self.layers:
+            x, group_token, attn_dict = layer(x, group_token, return_attn=return_attn)
+            attn_dict_list.append(attn_dict)
+
+        x = self.norm(x)
+
+        return x, group_token, attn_dict_list
+
+    def forward_image_head(self, x):
+        """
+
+        Args:
+            x: shape [B, L, C]
+
+        Returns:
+
+        """
+        # [B, L, C]
+        x = self.avgpool(x.transpose(1, 2))  # B C 1
+        x = torch.flatten(x, 1)
+        x = self.head(x)
+
+        return x
+
+    def forward(self, x, *, return_feat=False, return_attn=False, as_dict=False):
+        x, group_token, attn_dicts = self.forward_features(x, return_attn=return_attn)
+        x_feat = x if return_feat else None
+
+        outs = Result(as_dict=as_dict)
+
+        outs.append(self.forward_image_head(x), name='x')
+
+        if return_feat:
+            outs.append(x_feat, name='feat')
+
+        if return_attn:
+            outs.append(attn_dicts, name='attn_dicts')
+
+        return outs.as_return()

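As a sanity check of the shapes implied by the default arguments above (two grouping stages producing 64 and then 8 group tokens, final width 384), the following sketch runs a randomly initialized GroupViT on dummy images; it exercises only the code in this file.

```python
import torch

from models.group_vit import GroupViT

model = GroupViT().eval()              # defaults: patch_size=16, embed_dim=384, depths=[6, 3, 3]
images = torch.randn(2, 3, 224, 224)

with torch.no_grad():
    feats = model(images)              # [2, 384]: average-pooled tokens, head is an identity (num_classes=0)

    outs = model(images, return_feat=True, return_attn=True, as_dict=True)
    print(outs['x'].shape)             # torch.Size([2, 384])
    print(outs['feat'].shape)          # torch.Size([2, 8, 384]), 8 = num_output_groups[-1]
    # outs['attn_dicts'][0]['soft'] has shape [2, 1, 64, 196]:
    # 64 first-stage group tokens assigned over the 14x14 patch tokens.
```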
+ 74 - 0
models/misc.py

@@ -0,0 +1,74 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import math
+
+import torch.nn.functional as F
+
+
+class Result:
+
+    def __init__(self, as_dict=False):
+        if as_dict:
+            self.outs = {}
+        else:
+            self.outs = []
+
+    @property
+    def as_dict(self):
+        return isinstance(self.outs, dict)
+
+    def append(self, element, name=None):
+        if self.as_dict:
+            assert name is not None
+            self.outs[name] = element
+        else:
+            self.outs.append(element)
+
+    def update(self, **kwargs):
+        if self.as_dict:
+            self.outs.update(**kwargs)
+        else:
+            for v in kwargs.values():
+                self.outs.append(v)
+
+    def as_output(self):
+        if self.as_dict:
+            return self.outs
+        else:
+            return tuple(self.outs)
+
+    def as_return(self):
+        outs = self.as_output()
+        if self.as_dict:
+            return outs
+        if len(outs) == 1:
+            return outs[0]
+        return outs
+
+
+def interpolate_pos_encoding(pos_embed, H, W):
+    num_patches = H * W
+
+    N = pos_embed.shape[1]
+    if num_patches == N and W == H:
+        return pos_embed
+    patch_pos_embed = pos_embed
+    dim = pos_embed.shape[-1]
+    patch_pos_embed = F.interpolate(
+        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+        size=(H, W),
+        mode='bicubic',
+        align_corners=False)
+    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+    return patch_pos_embed

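`Result` is a small convenience wrapper that lets the encoders return either a tuple or a dict from the same code path, and `interpolate_pos_encoding` resizes a learned position embedding to a new token grid at test time. A short illustration of the wrapper:

```python
import torch

from models.misc import Result

# Tuple style: a single appended element is unwrapped by as_return().
res = Result(as_dict=False)
res.append(torch.zeros(2, 3), name='x')
out = res.as_return()                  # a bare tensor

# Dict style: names become keys.
res = Result(as_dict=True)
res.append(torch.zeros(2, 3), name='x')
res.update(feat=torch.zeros(2, 8, 3))
out = res.as_return()                  # {'x': tensor(...), 'feat': tensor(...)}
```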
+ 302 - 0
models/multi_label_contrastive.py

@@ -0,0 +1,302 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import diffdist.functional as diff_dist
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from timm.loss import SoftTargetCrossEntropy
+
+from .builder import MODELS
+from .misc import Result
+
+
+def dist_collect(x):
+    """ collect all tensor from all GPUs
+    args:
+        x: shape (mini_batch, ...)
+    returns:
+        shape (mini_batch * num_gpu, ...)
+    """
+    x = x.contiguous()
+    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())]
+    out_list = diff_dist.all_gather(out_list, x)
+    return torch.cat(out_list, dim=0).contiguous()
+
+
+class ProjectMLP(nn.Module):
+
+    def __init__(self, in_dim=256, inner_dim=4096, out_dim=256, num_layers=2):
+        super(ProjectMLP, self).__init__()
+        # hidden layers
+        linear_hidden = []
+        for i in range(num_layers - 1):
+            linear_hidden.append(nn.Conv1d(in_dim if i == 0 else inner_dim, inner_dim, kernel_size=1))
+            linear_hidden.append(nn.BatchNorm1d(inner_dim))
+            linear_hidden.append(nn.ReLU(inplace=True))
+        self.linear_hidden = nn.Sequential(*linear_hidden)
+
+        self.linear_out = nn.Conv1d(
+            in_dim if num_layers == 1 else inner_dim, out_dim, kernel_size=1) if num_layers >= 1 else nn.Identity()
+
+    def forward(self, x):
+        """
+
+        Args:
+            x (torch.Tensor): output of transformers, shape [B, L, C]
+
+        Returns:
+
+        """
+        assert x.ndim in [2, 3], x.ndim
+        add_dim = False
+        if x.ndim == 2:
+            # [B, C] -> [B, L, C]
+            x = x.unsqueeze(1)
+            add_dim = True
+
+        x = rearrange(x, 'b l c -> b c l')
+        x = self.linear_hidden(x)
+        x = self.linear_out(x)
+        x = rearrange(x, 'b c l -> b l c')
+
+        if add_dim:
+            x = x.squeeze(1)
+
+        return x
+
+
+@MODELS.register_module()
+class MultiLabelContrastive(nn.Module):
+
+    def __init__(self,
+                 img_encoder,
+                 text_encoder,
+                 output_dim=256,
+                 contrast_temperature=0.07,
+                 proj_num_layers=2,
+                 multi_label=0,
+                 share_temperature=False,
+                 multi_label_loss_weight=1.0):
+        super().__init__()
+
+        self.img_encoder = MODELS.build(img_encoder)
+        self.text_encoder = MODELS.build(text_encoder)
+
+        self.contrast_temperature = contrast_temperature
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
+        self.cross_entropy = nn.CrossEntropyLoss()
+        self.soft_cross_entropy = SoftTargetCrossEntropy()
+
+        self.proj_num_layers = proj_num_layers
+        self.multi_label = multi_label
+        if proj_num_layers > 0:
+            self.img_projector = ProjectMLP(
+                in_dim=self.img_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
+            self.text_projector = ProjectMLP(
+                in_dim=self.text_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
+            self.img_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.img_projector)
+            self.text_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.text_projector)
+
+        else:
+            self.img_projector = nn.Identity()
+            self.text_projector = nn.Identity()
+
+        self.share_temperature = share_temperature
+        if self.with_multi_label and not self.share_temperature:
+            self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
+        self.multi_label_loss_weight = multi_label_loss_weight
+
+    @property
+    def with_multi_label(self):
+        return self.multi_label > 0
+
+    def loss(self, image_x, text_x):
+
+        batch_size = image_x.shape[0]
+        # get label globally
+        labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + batch_size * dist.get_rank()
+
+        # [B, C]
+        image_x = F.normalize(image_x, dim=-1)
+        text_x = F.normalize(text_x, dim=-1)
+
+        logits_per_img = image_x @ dist_collect(text_x).t()
+        logits_per_text = text_x @ dist_collect(image_x).t()
+
+        logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
+        loss_img = self.cross_entropy(logits_per_img * logit_scale, labels)
+        loss_text = self.cross_entropy(logits_per_text * logit_scale, labels)
+
+        loss = 0.5 * (loss_img + loss_text)
+
+        return loss
+
+    def multi_label_loss(self, image_feat, text_feat):
+        """
+
+        Args:
+            image_feat (torch.Tensor): shape [B, L1, C]
+            text_feat (torch.Tensor): shape [B, L2, C]
+
+        Returns:
+
+        """
+        # [B, L1, C], L1 = 1
+        image_feat = F.normalize(image_feat, dim=-1)
+        # [B, L2, C]
+        text_feat = F.normalize(text_feat, dim=-1)
+
+        # [B, L1, L2]
+        dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
+        # [B, L2, L1]
+        dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')
+
+        if self.share_temperature:
+            logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
+        else:
+            logit_scale = torch.clamp(self.multi_label_logit_scale.exp(), max=100)
+
+        batch = image_feat.shape[0]
+        img_len = image_feat.shape[1]
+        text_len = text_feat.shape[1]
+        # [B, L1, L2]
+        pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
+        # [B, L2, L1]
+        pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')
+
+        image_x = rearrange(image_feat, 'b l c -> (b l) c')
+        text_x = rearrange(text_feat, 'b l c -> (b l) c')
+
+        logits_per_img = image_x @ dist_collect(text_x).t()
+        logits_per_text = text_x @ dist_collect(image_x).t()
+
+        # get label globally
+        # [B, L1, B, L2, W]
+        labels_per_img = F.one_hot(
+            torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * dist.get_rank(),
+            num_classes=dist.get_world_size()).to(image_x.dtype)
+        labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
+            torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
+        # [BxL1, WxBxL2]
+        labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
+        # [B, L2, B, L1, W]
+        labels_per_text = F.one_hot(
+            torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * dist.get_rank(),
+            num_classes=dist.get_world_size()).to(text_x.dtype)
+        labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
+            torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
+        # [BxL2, WxBxL1]
+        labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')
+
+        loss_img = self.soft_cross_entropy(logits_per_img * logit_scale, labels_per_img)
+        loss_text = self.soft_cross_entropy(logits_per_text * logit_scale, labels_per_text)
+
+        loss = 0.5 * (loss_img + loss_text)
+
+        return loss
+
+    def encode_image(self, image, *, return_feat=False, as_dict=False):
+        outs = Result(as_dict)
+        img_outs = self.img_encoder(image, return_feat=return_feat, as_dict=True)
+        outs.append(self.img_projector(img_outs['x']), 'image_x')
+        if return_feat:
+            outs.append(self.img_projector(img_outs['feat']), 'image_feat')
+        return outs.as_return()
+
+    def encode_text(self, text, *, as_dict=False):
+        assert text.ndim in [2, 3], text.ndim
+        squeeze_dim = False
+        num_text = 1
+        if text.ndim == 3:
+            num_text = text.shape[1]
+            text = rearrange(text, 'b n l -> (b n) l', n=num_text)
+            squeeze_dim = True
+
+        outs = Result(as_dict=as_dict)
+        # [B, C]
+        x = self.text_encoder(text)
+        text_x = self.text_projector(x)
+        outs.append(text_x, 'text_x')
+        if squeeze_dim:
+            text_x = rearrange(text_x, '(b n) c -> b n c', n=num_text)
+            text_multi_label_x = text_x[:, 1:]
+            text_x = text_x[:, 0]
+            outs.update(text_x=text_x, text_multi_label_x=text_multi_label_x)
+
+        return outs.as_return()
+
+    def forward_train(self, image, text):
+        image_outs = self.encode_image(image, as_dict=True)
+        # [B, C]
+        image_x = image_outs['image_x']
+
+        text_outs = self.encode_text(text, as_dict=True)
+        # [B, C]
+        text_x = text_outs['text_x']
+
+        losses = self.loss(image_x, text_x)
+
+        losses_dict = dict(loss=losses)
+        if self.with_multi_label:
+            image_multi_label_x = image_x.unsqueeze(1)
+            text_multi_label_x = text_outs['text_multi_label_x']
+            losses_dict['multi_label_loss'] = self.multi_label_loss(image_multi_label_x,
+                                                                    text_multi_label_x) * self.multi_label_loss_weight
+
+        return losses_dict
+
+    def forward_test(self, image, text):
+        return self.zero_shot_pred(image, text)
+
+    def forward(self, image, text):
+        if self.training:
+            return self.forward_train(image, text)
+        else:
+            return self.forward_test(image, text)
+
+    @torch.no_grad()
+    def build_text_embedding(self, text):
+        """
+
+        Args:
+            text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH]
+
+        Returns:
+
+        """
+        text = text.to(next(self.parameters()).device)
+        num_classes, num_templates = text.shape[:2]
+        text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
+        text_tokens = self.encode_text(text)
+        # [N, T, C]
+        text_tokens = rearrange(text_tokens, '(n t) c -> n t c', n=num_classes, t=num_templates)
+        # [N, C]
+        text_tokens = text_tokens.mean(dim=1)
+        text_tokens = F.normalize(text_tokens, dim=-1)
+
+        return text_tokens
+
+    @torch.no_grad()
+    def zero_shot_pred(self, image, text):
+        # [B, C]
+        image_features = self.encode_image(image)
+        image_features = F.normalize(image_features, dim=-1)
+
+        # cosine similarity as logits
+        logits_per_image = image_features @ text.t()
+
+        return logits_per_image

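`loss` above is the standard bidirectional image-text InfoNCE objective; `dist_collect` only enlarges the pool of negatives by gathering features from all GPUs, which requires an initialized process group. A single-process sketch of the same computation (without the distributed gather) for reference:

```python
import torch
import torch.nn.functional as F


def info_nce(image_x, text_x, logit_scale):
    """Bidirectional contrastive loss on one process (no cross-GPU negatives)."""
    image_x = F.normalize(image_x, dim=-1)
    text_x = F.normalize(text_x, dim=-1)
    labels = torch.arange(image_x.size(0), device=image_x.device)
    logits_per_img = image_x @ text_x.t() * logit_scale
    logits_per_text = text_x @ image_x.t() * logit_scale
    return 0.5 * (F.cross_entropy(logits_per_img, labels) +
                  F.cross_entropy(logits_per_text, labels))


loss = info_nce(torch.randn(8, 256), torch.randn(8, 256), logit_scale=1 / 0.07)
```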
+ 117 - 0
models/transformer.py

@@ -0,0 +1,117 @@
+# -------------------------------------------------------------------------
+# MIT License
+#
+# Copyright (c) 2021 OpenAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import torch
+import torch.utils.checkpoint as checkpoint
+from torch import nn
+
+from .builder import MODELS
+from .misc import Result
+from .utils import ResidualAttentionBlock
+
+
+class Transformer(nn.Module):
+
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_checkpoint=False):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+        proj_std = (self.width**-0.5) * ((2 * self.layers)**-0.5)
+        attn_std = self.width**-0.5
+        fc_std = (2 * self.width)**-0.5
+        for block in self.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        self.use_checkpoint = use_checkpoint
+
+    def forward(self, x: torch.Tensor):
+        for resblock in self.resblocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(resblock, x)
+            else:
+                x = resblock(x)
+        return x
+
+
+@MODELS.register_module()
+class TextTransformer(nn.Module):
+
+    def __init__(
+        self,
+        context_length: int,
+        width: int,
+        layers: int,
+        vocab_size,
+        use_checkpoint=False,
+    ):
+
+        super().__init__()
+        heads = width // 64
+        self.context_length = context_length
+        self.width = width
+        self.transformer = Transformer(
+            width=width,
+            layers=layers,
+            heads=heads,
+            attn_mask=self.build_attention_mask(),
+            use_checkpoint=use_checkpoint)
+
+        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
+        self.ln_final = nn.LayerNorm(width)
+        self.token_embedding = nn.Embedding(vocab_size, width)
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+
+        # initialization
+        nn.init.normal_(self.positional_embedding, std=0.01)
+
+    def build_attention_mask(self):
+        # lazily create a causal attention mask over the text tokens
+        # pytorch uses an additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float('-inf'))
+        mask.triu_(1)  # zero out the lower triangle (and diagonal): each token attends to itself and the past
+        return mask
+
+    def forward(self, text, *, as_dict=False):
+        x = self.token_embedding(text)
+        outs = Result(as_dict=as_dict)
+        x = x + self.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
+
+        outs.append(x, name='x')
+
+        return outs.as_return()

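A quick shape check for the text tower. The vocabulary size below matches the bundled CLIP BPE tokenizer (49,408 tokens); the width and depth are placeholder values rather than the released training configuration.

```python
import torch

from models.transformer import TextTransformer

text_encoder = TextTransformer(context_length=77, width=256, layers=4, vocab_size=49408)
tokens = torch.randint(0, 49408, (2, 77))   # dummy token ids; the EOT feature is taken at argmax
features = text_encoder(tokens)             # [2, 256]
```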
+ 59 - 0
models/utils.py

@@ -0,0 +1,59 @@
+# -------------------------------------------------------------------------
+# MIT License
+#
+# Copyright (c) 2021 OpenAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from collections import OrderedDict
+
+import torch
+from torch import nn
+
+
+class QuickGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = nn.LayerNorm(d_model)
+        self.mlp = nn.Sequential(
+            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), ('gelu', QuickGELU()),
+                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
+        self.ln_2 = nn.LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, key_padding_mask=key_padding_mask)[0]
+
+    def forward(self, x: torch.Tensor, key_padding_mask=None):
+        x = x + self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)
+        x = x + self.mlp(self.ln_2(x))
+        return x

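These are the standard CLIP building blocks: `QuickGELU` is the sigmoid-based GELU approximation and `ResidualAttentionBlock` is a pre-norm transformer block operating on sequence-first tensors, the layout `nn.MultiheadAttention` expects. A small usage sketch:

```python
import torch

from models.utils import ResidualAttentionBlock

block = ResidualAttentionBlock(d_model=64, n_head=4)
x = torch.randn(10, 2, 64)                  # [sequence_length, batch, d_model]
y = block(x)                                # same shape: [10, 2, 64]

# Padding positions can be ignored per sample via key_padding_mask (True = masked).
key_padding_mask = torch.zeros(2, 10, dtype=torch.bool)
key_padding_mask[:, 8:] = True
y = block(x, key_padding_mask=key_padding_mask)
```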
binary
segmentation/.DS_Store


binary
segmentation/configs/.DS_Store


binary
segmentation/configs/_base_/.DS_Store


+ 15 - 0
segmentation/configs/_base_/custom_import.py

@@ -0,0 +1,15 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+custom_imports = dict(
+    imports=['segmentation.datasets.coco_object', 'segmentation.datasets.pascal_voc'], allow_failed_imports=False)

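`custom_imports` is the standard mmcv hook for registering project-specific datasets before a config is used: when a config that inherits from this file is loaded, mmcv imports the listed modules so the COCO object and Pascal VOC dataset classes register themselves with mmseg. The equivalent manual call, assuming mmseg is installed and the working directory is the repository root, would be:

```python
from mmcv.utils import import_modules_from_strings

# Roughly what mmcv.Config does with the custom_imports block at load time.
import_modules_from_strings(
    imports=['segmentation.datasets.coco_object', 'segmentation.datasets.pascal_voc'],
    allow_failed_imports=False)
```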
+ 44 - 0
segmentation/configs/_base_/datasets/coco.py

@@ -0,0 +1,44 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+_base_ = ['../custom_import.py']
+# dataset settings
+dataset_type = 'COCOObjectDataset'
+data_root = 'local_data/coco'
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        # img_scale=(2048, 512),
+        img_scale=(2048, 448),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/val2017',
+        ann_dir='annotations/val2017',
+        pipeline=test_pipeline))
+
+# test_cfg = dict(bg_thresh=.95, mode='whole')
+test_cfg = dict(bg_thresh=.95, mode='slide', stride=(224, 224), crop_size=(448, 448))

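For evaluation this file is consumed through `mmcv.Config`. The `test_cfg` values mean sliding-window inference with 448x448 crops and a 224-pixel stride (50% overlap), plus a background threshold of 0.95 used by the zero-shot segmentation head. A minimal loading sketch, assuming mmseg is installed and the command is run from the repository root:

```python
from mmcv import Config

cfg = Config.fromfile('segmentation/configs/_base_/datasets/coco.py')
print(cfg.data.test.type)                            # 'COCOObjectDataset'
print(cfg.test_cfg.mode)                             # 'slide'
print(cfg.test_cfg.crop_size, cfg.test_cfg.stride)   # (448, 448) (224, 224)
```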
+ 43 - 0
segmentation/configs/_base_/datasets/pascal_context.py

@@ -0,0 +1,43 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+_base_ = ['../custom_import.py']
+# dataset settings
+dataset_type = 'PascalContextDataset'
+data_root = 'local_data/VOCdevkit/VOC2010/'
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 448),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClassContext',
+        split='ImageSets/SegmentationContext/val.txt',
+        pipeline=test_pipeline))
+
+test_cfg = dict(bg_thresh=.35, mode='slide', stride=(224, 224), crop_size=(448, 448))

+ 43 - 0
segmentation/configs/_base_/datasets/pascal_voc12.py

@@ -0,0 +1,43 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+_base_ = ['../custom_import.py']
+# dataset settings
+dataset_type = 'PascalVOCDataset'
+data_root = 'local_data/VOCdevkit/VOC2012'
+img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 448),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClass',
+        split='ImageSets/Segmentation/val.txt',
+        pipeline=test_pipeline))
+
+test_cfg = dict(bg_thresh=.95, mode='slide', stride=(224, 224), crop_size=(448, 448))

+ 18 - 0
segmentation/datasets/__init__.py

@@ -0,0 +1,18 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from .coco_object import COCOObjectDataset
+from .pascal_context import PascalContextDataset
+from .pascal_voc import PascalVOCDataset
+
+__all__ = ['COCOObjectDataset', 'PascalContextDataset', 'PascalVOCDataset']

+ 48 - 0
segmentation/datasets/coco_object.py

@@ -0,0 +1,48 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from mmseg.datasets import DATASETS, CustomDataset
+
+
+@DATASETS.register_module()
+class COCOObjectDataset(CustomDataset):
+    """COCO-Object dataset.
+
+    1 bg class + first 80 classes from the COCO-Stuff dataset.
+    """
+
+    CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
+               'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
+               'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
+               'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
+               'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
+               'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
+               'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
+               'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
+               'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
+
+    PALETTE = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224],
+               [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64],
+               [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
+               [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0],
+               [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32],
+               [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
+               [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32],
+               [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
+               [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
+               [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160],
+               [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0],
+               [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0]]
+
+    def __init__(self, **kwargs):
+        super(COCOObjectDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs)
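Note: `seg_map_suffix='_instanceTrainIds.png'` pairs each `*.jpg` with its pre-converted object mask. A usage sketch (assumes the repo root is on PYTHONPATH and `local_data/coco` is laid out as in `configs/_base_/datasets/coco.py`):

import segmentation.datasets  # noqa: F401, registers COCOObjectDataset with mmseg
from mmseg.datasets import build_dataset

dataset = build_dataset(
    dict(
        type='COCOObjectDataset',
        data_root='local_data/coco',
        img_dir='images/val2017',
        ann_dir='annotations/val2017',
        pipeline=[dict(type='LoadImageFromFile')]))
print(len(dataset))  # one entry per val2017 image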

+ 26 - 0
segmentation/datasets/pascal_context.py

@@ -0,0 +1,26 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from mmseg.datasets import DATASETS
+from mmseg.datasets import PascalContextDataset as _PascalContextDataset
+
+
+@DATASETS.register_module(force=True)
+class PascalContextDataset(_PascalContextDataset):
+
+    CLASSES = ('background', 'airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', 'boat', 'book',
+               'bottle', 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
+               'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
+               'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform', 'plant', 'road',
+               'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train', 'tree',
+               'truck', 'monitor', 'wall', 'water', 'window', 'wood')

+ 22 - 0
segmentation/datasets/pascal_voc.py

@@ -0,0 +1,22 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from mmseg.datasets import DATASETS
+from mmseg.datasets import PascalVOCDataset as _PascalVOCDataset
+
+
+@DATASETS.register_module(force=True)
+class PascalVOCDataset(_PascalVOCDataset):
+
+    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
+               'table', 'dog', 'horse', 'motorbike', 'person', 'plant', 'sheep', 'sofa', 'train', 'monitor')

+ 20 - 0
segmentation/evaluation/__init__.py

@@ -0,0 +1,20 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from .builder import build_seg_dataloader, build_seg_dataset, build_seg_demo_pipeline, build_seg_inference
+from .group_vit_seg import GROUP_PALETTE, GroupViTSegInference
+
+__all__ = [
+    'GroupViTSegInference', 'build_seg_dataset', 'build_seg_dataloader', 'build_seg_inference',
+    'build_seg_demo_pipeline', 'GROUP_PALETTE'
+]

+ 109 - 0
segmentation/evaluation/builder.py

@@ -0,0 +1,109 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import mmcv
+from mmseg.datasets import build_dataloader, build_dataset
+from mmseg.datasets.pipelines import Compose
+from omegaconf import OmegaConf
+from utils import build_dataset_class_tokens
+
+from .group_vit_seg import GroupViTSegInference
+
+
+def build_seg_dataset(config):
+    """Build a dataset from config."""
+    cfg = mmcv.Config.fromfile(config.cfg)
+    dataset = build_dataset(cfg.data.test)
+    return dataset
+
+
+def build_seg_dataloader(dataset):
+
+    data_loader = build_dataloader(
+        dataset,
+        samples_per_gpu=1,
+        workers_per_gpu=1,
+        dist=True,
+        shuffle=False,
+        persistent_workers=True,
+        pin_memory=False)
+    return data_loader
+
+
+def build_seg_inference(model, dataset, text_transform, config):
+    cfg = mmcv.Config.fromfile(config.cfg)
+    if len(config.opts):
+        cfg.merge_from_dict(OmegaConf.to_container(OmegaConf.from_dotlist(OmegaConf.to_container(config.opts))))
+    with_bg = dataset.CLASSES[0] == 'background'
+    if with_bg:
+        classnames = dataset.CLASSES[1:]
+    else:
+        classnames = dataset.CLASSES
+    text_tokens = build_dataset_class_tokens(text_transform, config.template, classnames)
+    text_embedding = model.build_text_embedding(text_tokens)
+    kwargs = dict(with_bg=with_bg)
+    if hasattr(cfg, 'test_cfg'):
+        kwargs['test_cfg'] = cfg.test_cfg
+    seg_model = GroupViTSegInference(model, text_embedding, **kwargs)
+
+    seg_model.CLASSES = dataset.CLASSES
+    seg_model.PALETTE = dataset.PALETTE
+
+    return seg_model
+
+
+class LoadImage:
+    """A simple pipeline to load image."""
+
+    def __call__(self, results):
+        """Call function to load images into results.
+
+        Args:
+            results (dict): A result dict contains the file name
+                of the image to be read.
+
+        Returns:
+            dict: ``results`` will be returned containing loaded image.
+        """
+
+        if isinstance(results['img'], str):
+            results['filename'] = results['img']
+            results['ori_filename'] = results['img']
+        else:
+            results['filename'] = None
+            results['ori_filename'] = None
+        img = mmcv.imread(results['img'])
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['ori_shape'] = img.shape
+        return results
+
+
+def build_seg_demo_pipeline():
+    """Build a demo pipeline from config."""
+    img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+    test_pipeline = Compose([
+        LoadImage(),
+        dict(
+            type='MultiScaleFlipAug',
+            img_scale=(2048, 448),
+            flip=False,
+            transforms=[
+                dict(type='Resize', keep_ratio=True),
+                dict(type='RandomFlip'),
+                dict(type='Normalize', **img_norm_cfg),
+                dict(type='ImageToTensor', keys=['img']),
+                dict(type='Collect', keys=['img']),
+            ])
+    ])
+    return test_pipeline
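Note: a usage sketch of the demo pipeline above (path and shapes are illustrative; requires mmcv/mmseg installed and the repo root on PYTHONPATH):

from segmentation.evaluation import build_seg_demo_pipeline

pipeline = build_seg_demo_pipeline()
data = pipeline(dict(img='demo/examples/voc.jpg'))
# MultiScaleFlipAug wraps every field in a list (one entry per scale/flip; here just one).
img_tensor = data['img'][0]        # CHW float tensor, keep-ratio resized and normalized
img_metas = data['img_metas'][0]   # DataContainer with filename, ori_shape, img_shape, scale_factor, ...
print(img_tensor.shape)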

+ 209 - 0
segmentation/evaluation/group_palette.txt

@@ -0,0 +1,209 @@
+128 64 128
+244 35 232
+70 70 70
+102 102 156
+190 153 153
+153 153 153
+250 170 30
+220 220 0
+107 142 35
+152 251 152
+70 130 180
+220 20 60
+255 0 0
+0 0 142
+0 0 70
+0 60 100
+0 80 100
+0 0 230
+119 11 32
+128 0 0
+0 128 0
+128 128 0
+0 0 128
+128 0 128
+0 128 128
+128 128 128
+64 0 0
+192 0 0
+64 128 0
+192 128 0
+64 0 128
+192 0 128
+64 128 128
+192 128 128
+0 64 0
+128 64 0
+0 192 0
+128 192 0
+0 64 128
+120 120 120
+180 120 120
+6 230 230
+80 50 50
+4 200 3
+120 120 80
+140 140 140
+204 5 255
+230 230 230
+4 250 7
+224 5 255
+235 255 7
+150 5 61
+120 120 70
+8 255 51
+255 6 82
+143 255 140
+204 255 4
+255 51 7
+204 70 3
+0 102 200
+61 230 250
+255 6 51
+11 102 255
+255 7 71
+128 0 0
+0 128 0
+128 128 0
+0 0 128
+128 0 128
+0 128 128
+128 128 128
+64 0 0
+192 0 0
+64 128 0
+192 128 0
+64 0 128
+192 0 128
+64 128 128
+192 128 128
+0 64 0
+128 64 0
+0 192 0
+128 192 0
+0 64 128
+255 9 224
+9 7 230
+220 220 220
+255 9 92
+112 9 255
+8 255 214
+7 255 224
+255 184 6
+10 255 71
+255 41 10
+7 255 255
+224 255 8
+102 8 255
+255 61 6
+255 194 7
+255 122 8
+0 255 20
+255 8 41
+255 5 153
+6 51 255
+235 12 255
+160 150 20
+0 163 255
+140 140 140
+250 10 15
+20 255 0
+31 255 0
+255 31 0
+255 224 0
+153 255 0
+0 0 255
+255 71 0
+0 235 255
+0 173 255
+31 0 255
+11 200 200
+255 82 0
+0 255 245
+0 61 255
+0 255 112
+0 255 133
+255 0 0
+255 163 0
+255 102 0
+194 255 0
+0 143 255
+51 255 0
+0 82 255
+0 255 41
+0 255 173
+10 0 255
+173 255 0
+0 255 153
+255 92 0
+255 0 255
+255 0 245
+255 0 102
+255 173 0
+255 0 20
+255 184 184
+0 31 255
+0 255 61
+0 71 255
+255 0 204
+0 255 194
+0 255 82
+0 10 255
+0 112 255
+51 0 255
+0 194 255
+0 122 255
+0 255 163
+255 153 0
+0 255 10
+255 112 0
+143 255 0
+82 0 255
+163 255 0
+255 235 0
+8 184 170
+133 0 255
+0 255 92
+184 0 255
+255 0 31
+0 184 255
+0 214 255
+255 0 112
+92 255 0
+0 224 255
+112 224 255
+70 184 160
+163 0 255
+153 0 255
+71 255 0
+255 0 163
+255 204 0
+255 0 143
+0 255 235
+133 255 0
+255 0 235
+245 0 255
+255 0 122
+255 245 0
+10 190 212
+214 255 0
+0 204 255
+20 0 255
+255 255 0
+0 153 255
+0 41 255
+0 255 204
+41 0 255
+41 255 0
+173 0 255
+0 245 255
+71 0 255
+122 0 255
+0 255 184
+0 92 255
+184 255 0
+0 133 255
+255 214 0
+25 194 194
+102 255 0
+92 0 255

+ 370 - 0
segmentation/evaluation/group_vit_seg.py

@@ -0,0 +1,370 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import os.path as osp
+
+import matplotlib.pyplot as plt
+import mmcv
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from mmseg.models import EncoderDecoder
+from PIL import Image
+from utils import get_logger
+
+GROUP_PALETTE = np.loadtxt(osp.join(osp.dirname(osp.abspath(__file__)), 'group_palette.txt'), dtype=np.uint8)[:, ::-1]
+
+
+def resize_attn_map(attentions, h, w, align_corners=False):
+    """
+    Upsample per-head group attention maps to the input image resolution.
+
+    Args:
+        attentions: attention maps of shape [B, num_head, H*W, groups]
+        h (int): target image height
+        w (int): target image width
+        align_corners (bool): passed to ``F.interpolate``
+
+    Returns:
+        attentions: attention maps of shape [B, num_head, h, w, groups]
+
+    """
+    scale = (h * w // attentions.shape[2])**0.5
+    if h > w:
+        w_featmap = w // int(np.round(scale))
+        h_featmap = attentions.shape[2] // w_featmap
+    else:
+        h_featmap = h // int(np.round(scale))
+        w_featmap = attentions.shape[2] // h_featmap
+    assert attentions.shape[
+        2] == h_featmap * w_featmap, f'{attentions.shape[2]} = {h_featmap} x {w_featmap}, h={h}, w={w}'
+
+    bs = attentions.shape[0]
+    nh = attentions.shape[1]  # number of head
+    groups = attentions.shape[3]  # number of group token
+    # [bs, nh, h*w, groups] -> [bs*nh, groups, h, w]
+    attentions = rearrange(
+        attentions, 'bs nh (h w) c -> (bs nh) c h w', bs=bs, nh=nh, h=h_featmap, w=w_featmap, c=groups)
+    attentions = F.interpolate(attentions, size=(h, w), mode='bilinear', align_corners=align_corners)
+    #  [bs*nh, groups, h, w] -> [bs, nh, h*w, groups]
+    attentions = rearrange(attentions, '(bs nh) c h w -> bs nh h w c', bs=bs, nh=nh, h=h, w=w, c=groups)
+
+    return attentions
+
+
+def top_groups(attn_map, k):
+    """
+    Args:
+        attn_map: (B, H, W, G)
+        k: int
+
+    Return:
+        (B, H, W, k)
+    """
+
+    attn_map = attn_map.clone()
+
+    for i in range(attn_map.size(0)):
+        # [H*W, G]
+        flatten_map = rearrange(attn_map[i], 'h w g -> (h w) g')
+        kept_mat = torch.zeros(flatten_map.shape[0], device=flatten_map.device, dtype=torch.bool)
+        area_per_group = flatten_map.sum(dim=0)
+        top_group_idx = area_per_group.topk(k=k).indices.cpu().numpy().tolist()
+        for group_idx in top_group_idx:
+            kept_mat[flatten_map.argmax(dim=-1) == group_idx] = True
+        # [H, W, 2]
+        coords = torch.stack(
+            torch.meshgrid(
+                torch.arange(attn_map[i].shape[0], device=attn_map[i].device, dtype=attn_map[i].dtype),
+                torch.arange(attn_map[i].shape[1], device=attn_map[i].device, dtype=attn_map[i].dtype)),
+            dim=-1)
+        coords = rearrange(coords, 'h w c -> (h w) c')
+
+        # calculate distance between each pair of points
+        # [non_kept, kept]
+        dist_mat = torch.sum((coords[~kept_mat].unsqueeze(1) - coords[kept_mat].unsqueeze(0))**2, dim=-1)
+
+        flatten_map[~kept_mat] = flatten_map[kept_mat.nonzero(as_tuple=True)[0][dist_mat.argmin(dim=-1)]]
+
+        attn_map[i] = flatten_map.reshape_as(attn_map[i])
+
+    return attn_map
+
+
+def seg2coord(seg_map):
+    """
+    Args:
+        seg_map (np.ndarray): (H, W)
+
+    Return:
+        dict(group_id -> (x, y))
+    """
+    h, w = seg_map.shape
+    # [h ,w, 2]
+    coords = np.stack(np.meshgrid(np.arange(h), np.arange(w), indexing='ij'), axis=-1)
+    labels = np.unique(seg_map)
+    coord_map = {}
+    for label in labels:
+        coord_map[label] = coords[seg_map == label].mean(axis=0)
+    return coord_map
+
+
+class GroupViTSegInference(EncoderDecoder):
+
+    def __init__(self, model, text_embedding, with_bg, test_cfg=dict(mode='whole', bg_thresh=.95)):
+        super(EncoderDecoder, self).__init__()
+        if not isinstance(test_cfg, mmcv.Config):
+            test_cfg = mmcv.Config(test_cfg)
+        self.test_cfg = test_cfg
+        self.model = model
+        # [N, C]
+        self.register_buffer('text_embedding', text_embedding)
+        self.with_bg = with_bg
+        self.bg_thresh = test_cfg['bg_thresh']
+        if self.with_bg:
+            self.num_classes = len(text_embedding) + 1
+        else:
+            self.num_classes = len(text_embedding)
+        self.align_corners = False
+        logger = get_logger()
+        logger.info(
+            f'Building GroupViTSegInference with {self.num_classes} classes, test_cfg={test_cfg}, with_bg={with_bg}')
+
+    def forward_train(self, img, img_metas, gt_semantic_seg):
+        raise NotImplementedError
+
+    def get_attn_maps(self, img, return_onehot=False, rescale=False):
+        """
+        Args:
+            img: [B, C, H, W]
+
+        Returns:
+            attn_maps: list[Tensor], attention map of shape [B, H, W, groups]
+        """
+        results = self.model.img_encoder(img, return_attn=True, as_dict=True)
+
+        attn_maps = []
+        with torch.no_grad():
+            prev_attn_masks = None
+            for idx, attn_dict in enumerate(results['attn_dicts']):
+                if attn_dict is None:
+                    assert idx == len(results['attn_dicts']) - 1, 'only last layer can be None'
+                    continue
+                # [B, G, HxW]
+                # B: batch size (1), nH: number of heads, G: number of group token
+                attn_masks = attn_dict['soft']
+                # [B, nH, G, HxW] -> [B, nH, HxW, G]
+                attn_masks = rearrange(attn_masks, 'b h g n -> b h n g')
+                if prev_attn_masks is None:
+                    prev_attn_masks = attn_masks
+                else:
+                    prev_attn_masks = prev_attn_masks @ attn_masks
+                # [B, nH, HxW, G] -> [B, nH, H, W, G]
+                attn_maps.append(resize_attn_map(prev_attn_masks, *img.shape[-2:]))
+
+        for i in range(len(attn_maps)):
+            attn_map = attn_maps[i]
+            # [B, nh, H, W, G]
+            assert attn_map.shape[1] == 1
+            # [B, H, W, G]
+            attn_map = attn_map.squeeze(1)
+
+            if rescale:
+                attn_map = rearrange(attn_map, 'b h w g -> b g h w')
+                attn_map = F.interpolate(
+                    attn_map, size=img.shape[2:], mode='bilinear', align_corners=self.align_corners)
+                attn_map = rearrange(attn_map, 'b g h w -> b h w g')
+
+            if return_onehot:
+                # [B, H, W, G]
+                attn_map = F.one_hot(attn_map.argmax(dim=-1), num_classes=attn_map.shape[-1]).to(dtype=attn_map.dtype)
+
+            attn_maps[i] = attn_map
+
+        return attn_maps
+
+    def encode_decode(self, img, img_metas):
+        """Encode images with backbone and decode into a semantic segmentation
+        map of the same size as input."""
+
+        assert img.shape[0] == 1, 'batch size must be 1'
+
+        # [B, C, H, W], get the last one only
+        attn_map = self.get_attn_maps(img, rescale=True)[-1]
+        # [H, W, G], select batch idx 0
+        attn_map = attn_map[0]
+
+        img_outs = self.model.encode_image(img, return_feat=True, as_dict=True)
+        # [B, L, C] -> [L, C]
+        grouped_img_tokens = img_outs['image_feat'].squeeze(0)
+        img_avg_feat = img_outs['image_x']
+        # [G, C]
+        grouped_img_tokens = F.normalize(grouped_img_tokens, dim=-1)
+        img_avg_feat = F.normalize(img_avg_feat, dim=-1)
+
+        # [H, W, G]
+        onehot_attn_map = F.one_hot(attn_map.argmax(dim=-1), num_classes=attn_map.shape[-1]).to(dtype=attn_map.dtype)
+
+        num_fg_classes = self.text_embedding.shape[0]
+        class_offset = 1 if self.with_bg else 0
+        text_tokens = self.text_embedding
+        num_classes = num_fg_classes + class_offset
+
+        logit_scale = torch.clamp(self.model.logit_scale.exp(), max=100)
+        # [G, N]
+        group_affinity_mat = (grouped_img_tokens @ text_tokens.T) * logit_scale
+        pre_group_affinity_mat = F.softmax(group_affinity_mat, dim=-1)
+
+        avg_affinity_mat = (img_avg_feat @ text_tokens.T) * logit_scale
+        avg_affinity_mat = F.softmax(avg_affinity_mat, dim=-1)
+        affinity_mask = torch.zeros_like(avg_affinity_mat)
+        avg_affinity_topk = avg_affinity_mat.topk(dim=-1, k=min(5, num_fg_classes))
+        affinity_mask.scatter_add_(
+            dim=-1, index=avg_affinity_topk.indices, src=torch.ones_like(avg_affinity_topk.values))
+        group_affinity_mat.masked_fill_(~affinity_mask.bool(), float('-inf'))
+
+        group_affinity_mat = F.softmax(group_affinity_mat, dim=-1)
+
+        # TODO: check if necessary
+        group_affinity_mat *= pre_group_affinity_mat
+
+        pred_logits = torch.zeros(num_classes, *attn_map.shape[:2], device=img.device, dtype=img.dtype)
+
+        pred_logits[class_offset:] = rearrange(onehot_attn_map @ group_affinity_mat, 'h w c -> c h w')
+        if self.with_bg:
+            bg_thresh = min(self.bg_thresh, group_affinity_mat.max().item())
+            pred_logits[0, (onehot_attn_map @ group_affinity_mat).max(dim=-1).values < bg_thresh] = 1
+
+        return pred_logits.unsqueeze(0)
+
+    def blend_result(self, img, result, palette=None, out_file=None, opacity=0.5, with_bg=False):
+        img = mmcv.imread(img)
+        img = img.copy()
+        seg = result[0]
+        if palette is None:
+            palette = self.PALETTE
+        palette = np.array(palette)
+        assert palette.shape[1] == 3, palette.shape
+        assert len(palette.shape) == 2
+        assert 0 < opacity <= 1.0
+        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+        for label, color in enumerate(palette):
+            color_seg[seg == label, :] = color
+        # convert to BGR
+        color_seg = color_seg[..., ::-1]
+
+        if with_bg:
+            fg_mask = seg != 0
+            img[fg_mask] = img[fg_mask] * (1 - opacity) + color_seg[fg_mask] * opacity
+        else:
+            img = img * (1 - opacity) + color_seg * opacity
+        img = img.astype(np.uint8)
+
+        if out_file is not None:
+            mmcv.imwrite(img, out_file)
+
+        return img
+
+    def show_result(self, img_show, img_tensor, result, out_file, vis_mode='input'):
+
+        assert vis_mode in [
+            'input', 'pred', 'input_pred', 'all_groups', 'first_group', 'final_group', 'input_pred_label'
+        ], vis_mode
+
+        if vis_mode == 'input':
+            mmcv.imwrite(img_show, out_file)
+        elif vis_mode == 'pred':
+            output = Image.fromarray(result[0].astype(np.uint8)).convert('P')
+            output.putpalette(np.array(self.PALETTE).astype(np.uint8))
+            mmcv.mkdir_or_exist(osp.dirname(out_file))
+            output.save(out_file.replace('.jpg', '.png'))
+        elif vis_mode == 'input_pred':
+            self.blend_result(img=img_show, result=result, out_file=out_file, opacity=0.5, with_bg=self.with_bg)
+        elif vis_mode == 'input_pred_label':
+            labels = np.unique(result[0])
+            coord_map = seg2coord(result[0])
+            # reference: https://github.com/open-mmlab/mmdetection/blob/ff9bc39913cb3ff5dde79d3933add7dc2561bab7/mmdet/models/detectors/base.py#L271 # noqa
+            blended_img = self.blend_result(
+                img=img_show, result=result, out_file=None, opacity=0.5, with_bg=self.with_bg)
+            blended_img = mmcv.bgr2rgb(blended_img)
+            width, height = img_show.shape[1], img_show.shape[0]
+            EPS = 1e-2
+            fig = plt.figure(frameon=False)
+            canvas = fig.canvas
+            dpi = fig.get_dpi()
+            fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)
+
+            # remove white edges by set subplot margin
+            plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
+            ax = plt.gca()
+            ax.axis('off')
+            for i, label in enumerate(labels):
+                if self.with_bg and label == 0:
+                    continue
+                center = coord_map[label].astype(np.int32)
+                label_text = self.CLASSES[label]
+                ax.text(
+                    center[1],
+                    center[0],
+                    f'{label_text}',
+                    bbox={
+                        'facecolor': 'black',
+                        'alpha': 0.5,
+                        'pad': 0.7,
+                        'edgecolor': 'none'
+                    },
+                    color='orangered',
+                    fontsize=16,
+                    verticalalignment='top',
+                    horizontalalignment='left')
+            plt.imshow(blended_img)
+            stream, _ = canvas.print_to_buffer()
+            buffer = np.frombuffer(stream, dtype='uint8')
+            img_rgba = buffer.reshape(height, width, 4)
+            rgb, alpha = np.split(img_rgba, [3], axis=2)
+            img = rgb.astype('uint8')
+            img = mmcv.rgb2bgr(img)
+            mmcv.imwrite(img, out_file)
+            plt.close()
+
+        elif vis_mode == 'all_groups' or vis_mode == 'final_group' or vis_mode == 'first_group':
+            attn_map_list = self.get_attn_maps(img_tensor)
+            assert len(attn_map_list) in [1, 2]
+            # only show 16 groups for the first stage
+            # if len(attn_map_list) == 2:
+            #     attn_map_list[0] = top_groups(attn_map_list[0], k=16)
+
+            num_groups = [attn_map_list[layer_idx].shape[-1] for layer_idx in range(len(attn_map_list))]
+            for layer_idx, attn_map in enumerate(attn_map_list):
+                if vis_mode == 'first_group' and layer_idx != 0:
+                    continue
+                if vis_mode == 'final_group' and layer_idx != len(attn_map_list) - 1:
+                    continue
+                attn_map = rearrange(attn_map, 'b h w g -> b g h w')
+                attn_map = F.interpolate(
+                    attn_map, size=img_show.shape[:2], mode='bilinear', align_corners=self.align_corners)
+                group_result = attn_map.argmax(dim=1).cpu().numpy()
+                if vis_mode == 'all_groups':
+                    layer_out_file = out_file.replace(
+                        osp.splitext(out_file)[-1], f'_layer{layer_idx}{osp.splitext(out_file)[-1]}')
+                else:
+                    layer_out_file = out_file
+                self.blend_result(
+                    img=img_show,
+                    result=group_result,
+                    palette=GROUP_PALETTE[sum(num_groups[:layer_idx]):sum(num_groups[:layer_idx + 1])],
+                    out_file=layer_out_file,
+                    opacity=0.5)
+        else:
+            raise ValueError(f'Unknown vis_mode: {vis_mode}')
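Note: stripped of the text matching, the attention-to-segment step in `get_attn_maps`/`encode_decode` is an upsample of the group assignment followed by a per-pixel argmax. A self-contained sketch with random tensors (toy 14x14 feature map, no GroupViT weights):

import torch
import torch.nn.functional as F
from einops import rearrange

B, nH, G, h, w = 1, 1, 8, 224, 224
attn = torch.rand(B, nH, 14 * 14, G).softmax(dim=-1)                   # fake [B, nH, H*W, G] attention
attn = rearrange(attn, 'b n (fh fw) g -> (b n) g fh fw', fh=14, fw=14)
attn = F.interpolate(attn, size=(h, w), mode='bilinear', align_corners=False)
attn = rearrange(attn, '(b n) g h w -> b n h w g', b=B, n=nH)
onehot = F.one_hot(attn.argmax(dim=-1), num_classes=G).to(attn.dtype)  # hard group per pixel
print(onehot.shape)  # torch.Size([1, 1, 224, 224, 8])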

+ 32 - 0
setup.cfg

@@ -0,0 +1,32 @@
+[yapf]
+based_on_style = pep8
+blank_line_before_nested_class_or_def = true
+split_before_expression_after_opening_paren = true
+column_limit = 120
+
+[isort]
+line_length = 120
+multi_line_output = 0
+known_standard_library = setuptools
+known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,packaging,prettytable,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,torch,ts
+no_lines_before = STDLIB,LOCALFOLDER
+default_section = THIRDPARTY
+
+[codespell]
+skip = *.po,*.ts,*.ipynb
+count =
+quiet-level = 3
+ignore-words-list = formating,sur,hist
+
+[flake8]
+ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811
+max-line-length = 120
+max-complexity = 18
+select = B,C,E,F,W,T4,B9
+exclude = build
+per-file-ignores =
+  **/__init__.py:F401,F403,E402
+  **/configs/**.py:F401,E402
+  configs/**.py:F401,E402
+  **/tests/config/**.py:F401,E402
+  tests/config/**.py:F401,E402

+ 24 - 0
tools/dist_launch.sh

@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+
+SCRIPT=$1
+CONFIG=$2
+GPUS=$3
+PORT=${PORT:-29500}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
+    $SCRIPT --cfg $CONFIG ${@:4}

+ 29 - 0
tools/dist_mn_launch.sh

@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+
+SCRIPT=$1
+CONFIG=$2
+NODE_RANK=$3
+NODES=$4
+GPUS_PER_NODE=$5
+MASTER_ADDR=$6
+PORT=${PORT:-29500}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
+python -m torch.distributed.launch --nproc_per_node=$GPUS_PER_NODE \
+  --nnodes=$NODES --node_rank=$NODE_RANK \
+  --master_addr=$MASTER_ADDR  \
+    $SCRIPT --cfg $CONFIG ${@:7}

+ 25 - 0
utils/__init__.py

@@ -0,0 +1,25 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from .checkpoint import auto_resume_helper, load_checkpoint, save_checkpoint
+from .config import get_config
+from .logger import get_logger
+from .lr_scheduler import build_scheduler
+from .misc import build_dataset_class_tokens, data2cuda, get_batch_size, get_grad_norm, parse_losses, reduce_tensor
+from .optimizer import build_optimizer
+
+__all__ = [
+    'get_config', 'get_logger', 'build_optimizer', 'build_scheduler', 'load_checkpoint', 'save_checkpoint',
+    'auto_resume_helper', 'reduce_tensor', 'get_grad_norm', 'get_batch_size', 'data2cuda', 'parse_losses',
+    'build_dataset_class_tokens'
+]

+ 145 - 0
utils/checkpoint.py

@@ -0,0 +1,145 @@
+# -------------------------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE
+#
+# Written by Ze Liu, Zhenda Xie
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import os
+from collections import defaultdict
+
+import torch
+import torch.distributed as dist
+from mmcv.runner import CheckpointLoader
+from omegaconf import read_write
+
+from .logger import get_logger
+
+try:
+    # noinspection PyUnresolvedReferences
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+def load_checkpoint(config, model, optimizer, lr_scheduler):
+    logger = get_logger()
+    logger.info(f'==============> Resuming from {config.checkpoint.resume}....................')
+    checkpoint = CheckpointLoader.load_checkpoint(config.checkpoint.resume, map_location='cpu')
+    msg = model.load_state_dict(checkpoint['model'], strict=False)
+    logger.info(msg)
+    metrics = defaultdict(float)
+    if (not config.evaluate.eval_only and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint
+            and 'epoch' in checkpoint):
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        with read_write(config):
+            config.train.start_epoch = checkpoint['epoch'] + 1
+        if 'amp' in checkpoint and config.train.amp_opt_level != 'O0' and checkpoint[
+                'config'].train.amp_opt_level != 'O0':
+            amp.load_state_dict(checkpoint['amp'])
+        logger.info(f"=> loaded successfully '{config.checkpoint.resume}' (epoch {checkpoint['epoch']})")
+        metrics = checkpoint['metrics']
+
+    del checkpoint
+    torch.cuda.empty_cache()
+    return metrics
+
+
+def save_checkpoint(config, epoch, model, metrics, optimizer, lr_scheduler, suffix=''):
+    save_state = {
+        'model': model.state_dict(),
+        'optimizer': optimizer.state_dict(),
+        'lr_scheduler': lr_scheduler.state_dict(),
+        'metrics': metrics,
+        'epoch': epoch,
+        'config': config
+    }
+    logger = get_logger()
+
+    for k, v in metrics.items():
+        save_state[k] = v
+
+    if config.train.amp_opt_level != 'O0':
+        save_state['amp'] = amp.state_dict()
+
+    if len(suffix) > 0 and not suffix.startswith('_'):
+        suffix = '_' + suffix
+    filename = f'ckpt_epoch_{epoch}{suffix}.pth'
+
+    save_path = os.path.join(config.output, filename)
+    logger.info(f'{save_path} saving......')
+    torch.save(save_state, save_path)
+    torch.save(save_state, os.path.join(config.output, 'checkpoint.pth'))
+    logger.info(f'{save_path} saved !!!')
+
+    if config.checkpoint.max_kept > 0:
+        if epoch >= config.checkpoint.max_kept:
+            logger.info(f'Epoch: {epoch}, greater than config.checkpoint.max_kept: {config.checkpoint.max_kept}')
+            end_clean_epoch = epoch - config.checkpoint.max_kept
+            old_path_list = []
+            for cur_clean_epoch in range(end_clean_epoch + 1):
+                old_path = os.path.join(config.output, f'ckpt_epoch_{cur_clean_epoch}{suffix}.pth')
+                if os.path.exists(old_path):
+                    logger.info(f'old checkpoint path {old_path} exists')
+                    old_path_list.append(old_path)
+            for old_path in old_path_list[:-config.checkpoint.max_kept]:
+                os.remove(old_path)
+                logger.info(f'old checkpoint path {old_path} removed!!!')
+
+
+def get_grad_norm(parameters, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item()**norm_type
+    total_norm = total_norm**(1. / norm_type)
+    return total_norm
+
+
+def auto_resume_helper(output_dir):
+    if os.path.exists(os.path.join(output_dir, 'checkpoint.pth')):
+        return os.path.join(output_dir, 'checkpoint.pth')
+
+    checkpoints = os.listdir(output_dir)
+    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
+    print(f'All checkpoints found in {output_dir}: {checkpoints}')
+    if len(checkpoints) > 0:
+        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
+        print(f'The latest checkpoint found: {latest_checkpoint}')
+        resume_file = latest_checkpoint
+    else:
+        resume_file = None
+    return resume_file
+
+
+def reduce_tensor(tensor):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+    rt /= dist.get_world_size()
+    return rt
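Note: `get_grad_norm` reports the global L2 norm of all gradients, the same quantity `torch.nn.utils.clip_grad_norm_` clips against. A tiny sanity check (the import assumes the repo root is on PYTHONPATH; `utils/__init__.py` pulls in the repo's other dependencies):

import torch
from utils.checkpoint import get_grad_norm

p1 = torch.nn.Parameter(torch.zeros(2))
p2 = torch.nn.Parameter(torch.zeros(3))
p1.grad = torch.tensor([3.0, 0.0])
p2.grad = torch.tensor([0.0, 4.0, 0.0])
print(get_grad_norm([p1, p2]))  # 5.0, i.e. sqrt(3**2 + 4**2)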

+ 77 - 0
utils/config.py

@@ -0,0 +1,77 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import os
+import os.path as osp
+
+from omegaconf import OmegaConf
+
+
+def load_config(cfg_file):
+    cfg = OmegaConf.load(cfg_file)
+    if '_base_' in cfg:
+        if isinstance(cfg._base_, str):
+            base_cfg = OmegaConf.load(osp.join(osp.dirname(cfg_file), cfg._base_))
+        else:
+            base_cfg = OmegaConf.merge(OmegaConf.load(f) for f in cfg._base_)
+        cfg = OmegaConf.merge(base_cfg, cfg)
+    return cfg
+
+
+def get_config(args):
+    cfg = load_config(args.cfg)
+    OmegaConf.set_struct(cfg, True)
+
+    if args.opts is not None:
+        cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(args.opts))
+    if hasattr(args, 'batch_size') and args.batch_size:
+        cfg.data.batch_size = args.batch_size
+
+    if hasattr(args, 'amp_opt_level') and args.amp_opt_level:
+        cfg.train.amp_opt_level = args.amp_opt_level
+
+    if hasattr(args, 'resume') and args.resume:
+        cfg.checkpoint.resume = args.resume
+
+    if hasattr(args, 'eval') and args.eval:
+        cfg.evaluate.eval_only = args.eval
+
+    if hasattr(args, 'keep') and args.keep:
+        cfg.checkpoint.max_kept = args.keep
+
+    if not cfg.model_name:
+        cfg.model_name = osp.splitext(osp.basename(args.cfg))[0]
+
+    world_size = int(os.environ.get('WORLD_SIZE', 1))
+    cfg.model_name = cfg.model_name + f'_bs{cfg.data.batch_size}x{world_size}'
+
+    if hasattr(args, 'output') and args.output:
+        cfg.output = args.output
+    else:
+        cfg.output = osp.join('output', cfg.model_name)
+
+    if hasattr(args, 'tag') and args.tag:
+        cfg.tag = args.tag
+        cfg.output = osp.join(cfg.output, cfg.tag)
+
+    if hasattr(args, 'wandb') and args.wandb:
+        cfg.wandb = args.wandb
+
+    if hasattr(args, 'vis') and args.vis:
+        cfg.vis = args.vis
+
+    cfg.local_rank = args.local_rank
+
+    OmegaConf.set_readonly(cfg, True)
+
+    return cfg
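Note: `load_config` layers a child config over its `_base_` file(s) with `OmegaConf.merge`, where later arguments win. A toy illustration with in-memory dicts instead of YAML files (values are made up):

from omegaconf import OmegaConf

base = OmegaConf.create({'train': {'epochs': 30, 'base_lr': 1.6e-3}})
child = OmegaConf.create({'train': {'epochs': 50}})  # stands in for the config that declares _base_
cfg = OmegaConf.merge(base, child)                   # child overrides base, key by key
print(cfg.train.epochs, cfg.train.base_lr)           # 50 0.0016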

+ 61 - 0
utils/logger.py

@@ -0,0 +1,61 @@
+# -------------------------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE
+#
+# Written by Ze Liu, Zhenda Xie
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import logging
+import os.path as osp
+
+from mmcv.utils import get_logger as get_root_logger
+from termcolor import colored
+
+logger_name = None
+
+
+def get_logger(cfg=None, log_level=logging.INFO):
+    global logger_name
+    if cfg is None:
+        return get_root_logger(logger_name)
+
+    # creating logger
+    name = cfg.model_name
+    output = cfg.output
+    logger_name = name
+
+    logger = get_root_logger(name, osp.join(output, 'log.txt'), log_level=log_level, file_mode='a')
+
+    fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
+    color_fmt = colored('[%(asctime)s %(name)s]', 'green') \
+        + colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
+
+    for handler in logger.handlers:
+        if isinstance(handler, logging.StreamHandler):
+            handler.setFormatter(logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+
+        if isinstance(handler, logging.FileHandler):
+            handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+
+    return logger

+ 36 - 0
utils/lr_scheduler.py

@@ -0,0 +1,36 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from timm.scheduler.cosine_lr import CosineLRScheduler
+
+
+def build_scheduler(config, optimizer, n_iter_per_epoch):
+    num_steps = int(config.epochs * n_iter_per_epoch)
+    warmup_steps = int(config.warmup_epochs * n_iter_per_epoch)
+
+    lr_scheduler = None
+    if config.lr_scheduler.name == 'cosine':
+        lr_scheduler = CosineLRScheduler(
+            optimizer,
+            t_initial=num_steps,
+            t_mul=1.,
+            lr_min=config.min_lr,
+            warmup_lr_init=config.warmup_lr,
+            warmup_t=warmup_steps,
+            cycle_limit=1,
+            t_in_epochs=False,
+        )
+    else:
+        raise NotImplementedError(f'lr scheduler {config.lr_scheduler.name} not implemented')
+
+    return lr_scheduler
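Note: with `t_in_epochs=False` the scheduler is stepped once per iteration, giving a linear warmup followed by a single cosine decay to `min_lr`. A rough reimplementation of the curve for inspection (plain math, not the timm code; the numbers below are illustrative, not the training defaults):

import math

def cosine_with_warmup(step, num_steps, warmup_steps, base_lr, warmup_lr, min_lr):
    if step < warmup_steps:
        return warmup_lr + (base_lr - warmup_lr) * step / max(warmup_steps, 1)  # linear warmup
    t = (step - warmup_steps) / max(num_steps - warmup_steps, 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))      # cosine decay

print(cosine_with_warmup(0, 1000, 100, 1.6e-3, 4e-6, 4e-5))     # 4e-06 (start of warmup)
print(cosine_with_warmup(1000, 1000, 100, 1.6e-3, 4e-6, 4e-5))  # 4e-05 (decayed to min_lr)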

+ 94 - 0
utils/misc.py

@@ -0,0 +1,94 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
+# property and proprietary rights in and to this software, related
+# documentation and any modifications thereto.  Any use, reproduction,
+# disclosure or distribution of this software and related documentation
+# without an express license agreement from NVIDIA CORPORATION is strictly
+# prohibited.
+#
+# Written by Jiarui Xu
+# -------------------------------------------------------------------------
+
+import collections.abc
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+from datasets import template_meta
+
+
+def reduce_tensor(tensor):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+    rt /= dist.get_world_size()
+    return rt
+
+
+def get_grad_norm(parameters, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item()**norm_type
+    total_norm = total_norm**(1. / norm_type)
+    return total_norm
+
+
+def get_batch_size(data):
+
+    if isinstance(data, torch.Tensor):
+        return data.size(0)
+    elif isinstance(data, collections.abc.Mapping):
+        return get_batch_size(data[next(iter(data))])
+    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
+        # check to make sure that the elements in batch have consistent size
+        it = iter(data)
+        return get_batch_size(next(it))
+
+    raise TypeError
+
+
+def data2cuda(data):
+
+    if isinstance(data, torch.Tensor):
+        batch = data.cuda(non_blocking=True)
+        return batch
+    elif isinstance(data, collections.abc.Mapping):
+        return {key: data2cuda(data[key]) for key in data}
+    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
+        return [data2cuda(d) for d in data]
+    else:
+        raise TypeError
+
+
+def parse_losses(losses):
+    log_vars = OrderedDict()
+    for loss_name, loss_value in losses.items():
+        if isinstance(loss_value, torch.Tensor):
+            log_vars[loss_name] = loss_value.mean()
+        elif isinstance(loss_value, list):
+            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+        else:
+            raise TypeError(f'{loss_name} is not a tensor or list of tensors')
+
+    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
+
+    return loss, log_vars
+
+
+def build_dataset_class_tokens(text_transform, template_set, classnames):
+
+    tokens = []
+    templates = template_meta[template_set]
+    for classname in classnames:
+        # format with class
+        tokens.append(torch.stack([text_transform(template.format(classname)) for template in templates]))
+    # [N, T, L], N: number of instance, T: number of captions (including ensembled), L: sequence length
+    tokens = torch.stack(tokens)
+
+    return tokens
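Note: `parse_losses` averages each entry and sums every key containing 'loss' into the scalar total. A small check with dummy values (the import assumes the repo root is on PYTHONPATH and its dependencies are installed):

import torch
from utils import parse_losses

losses = {'loss_img': torch.tensor(0.5), 'loss_text': [torch.tensor(0.25), torch.tensor(0.25)]}
total, log_vars = parse_losses(losses)
print(float(total))                                # 1.0 = 0.5 + (0.25 + 0.25)
print({k: float(v) for k, v in log_vars.items()})  # {'loss_img': 0.5, 'loss_text': 0.5}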

+ 72 - 0
utils/optimizer.py

@@ -0,0 +1,72 @@
+# -------------------------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE
+#
+# Written by Ze Liu, Zhenda Xie
+# Modified by Jiarui Xu
+# -------------------------------------------------------------------------
+
+from torch import optim as optim
+
+
+def build_optimizer(config, model):
+    """Build optimizer, set weight decay of normalization to 0 by default."""
+    parameters = set_weight_decay(model, {}, {})
+
+    opt_name = config.optimizer.name
+    optimizer = None
+    if opt_name == 'adamw':
+        optimizer = optim.AdamW(
+            parameters,
+            eps=config.optimizer.eps,
+            betas=config.optimizer.betas,
+            lr=config.base_lr,
+            weight_decay=config.weight_decay)
+    else:
+        raise ValueError(f'Unsupported optimizer: {opt_name}')
+
+    return optimizer
+
+
+def set_weight_decay(model, skip_list=(), skip_keywords=()):
+    has_decay = []
+    no_decay = []
+
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue  # frozen weights
+        if len(param.shape) == 1 or name.endswith('.bias') or (name in skip_list) or \
+                check_keywords_in_name(name, skip_keywords):
+            no_decay.append(param)
+            # print(f"{name} has no weight decay")
+        else:
+            has_decay.append(param)
+    return [{'params': has_decay}, {'params': no_decay, 'weight_decay': 0.}]
+
+
+def check_keywords_in_name(name, keywords=()):
+    isin = False
+    for keyword in keywords:
+        if keyword in name:
+            isin = True
+    return isin
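Note: `set_weight_decay` routes biases and 1-D parameters (e.g. LayerNorm weights) into a group with `weight_decay=0.`, so only matrix-shaped weights are regularized. A toy check (the import assumes the repo root is on PYTHONPATH and the repo's dependencies are installed):

import torch.nn as nn
from utils.optimizer import set_weight_decay

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
decay, no_decay = set_weight_decay(model)
# Linear.weight (2-D) keeps decay; Linear.bias and both LayerNorm params do not.
print(len(decay['params']), len(no_decay['params']))  # 1 3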