Просмотр исходного кода

feat(dataset): 添加 CUHK-PEDES 和 ImageNet 数据集处理脚本

- 新增 CUHK-PEDES 数据集处理脚本,用于将数据集转换为 webdataset 格式
- 新增 ImageNet 数据集分类脚本,用于将验证集图片按类别分类
- 新增 ImageNet 数据集提取脚本,用于解压训练集和验证集数据
Yijun Fu 1 месяц назад
Сommit
0728f2170a
3 измененных файлов с 422 добавлено и 0 удалено
  1. 348 0
      cuhkpedes/CUHK-PEDES2webdataset.ipynb
  2. 33 0
      imagenet/classify_val_images.py
  3. 41 0
      imagenet/extract_imagenet.py

+ 348 - 0
cuhkpedes/CUHK-PEDES2webdataset.ipynb

@@ -0,0 +1,348 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/mnt/vos-s9gjtkm2/reid/miniconda3/envs/groupvit/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "import pandas as pd\n",
+    "import os\n",
+    "import json\n",
+    "import webdataset as wds\n",
+    "import re\n",
+    "import random"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/mnt/vos-s9gjtkm2/reid/dataset/cross_reid\n"
+     ]
+    }
+   ],
+   "source": [
+    "current_path = os.getcwd()\n",
+    "print(current_path)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "CUHK_PEDES_path = os.path.join(current_path, 'CUHK-PEDES')\n",
+    "annotation_path = os.path.join(CUHK_PEDES_path, 'processed_data')\n",
+    "image_path = os.path.join(CUHK_PEDES_path, 'imgs')\n",
+    "train_json_path = os.path.join(annotation_path, 'train.json')\n",
+    "val_json_path = os.path.join(annotation_path, 'val.json')\n",
+    "test_json_path = os.path.join(annotation_path, 'test.json')\n",
+    "base = os.path.join(current_path, 'CUHK-PEDES_shards')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "  split                                           captions  \\\n",
+      "0   val  [The man has short, dark hair and wears khaki ...   \n",
+      "1   val  [A man with a gray hoodie, book bag, and khaki...   \n",
+      "2   val  [The man is wearing a grey hooded sweater, bro...   \n",
+      "3   val  [Man wearing a grey jacket, brown pants and bl...   \n",
+      "4   val  [The woman is wearing a floral printed shirt w...   \n",
+      "\n",
+      "                    file_path  \\\n",
+      "0          CUHK01/0107002.png   \n",
+      "1          CUHK01/0107004.png   \n",
+      "2          CUHK01/0107001.png   \n",
+      "3          CUHK01/0107003.png   \n",
+      "4  test_query/p5969_s7727.jpg   \n",
+      "\n",
+      "                                    processed_tokens     id  \n",
+      "0  [[the, man, has, short, dark, hair, and, wears...  11004  \n",
+      "1  [[a, man, with, a, gray, hoodie, book, bag, an...  11004  \n",
+      "2  [[the, man, is, wearing, a, grey, hooded, swea...  11004  \n",
+      "3  [[man, wearing, a, grey, jacket, brown, pants,...  11004  \n",
+      "4  [[the, woman, is, wearing, a, floral, printed,...  11005  \n",
+      "3078\n"
+     ]
+    }
+   ],
+   "source": [
+    "with open(val_json_path, 'r') as file:\n",
+    "    data = json.load(file)\n",
+    "\n",
+    "train_json = pd.DataFrame(data)\n",
+    "print(train_json.head())\n",
+    "print(train_json.shape[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 定义一个函数来替换人称代词为对应的名词,并将名词替换为 {id}\n",
+    "def replace_pronouns_and_nouns(sentence, id):\n",
+    "    replacements = {\n",
+    "        r'\\bmale\\b': f'male_{id}',\n",
+    "        r'\\bman\\b': f'male_{id}',\n",
+    "        r'\\bmans\\b': f'male_{id}',\n",
+    "        r'\\bhe\\b': f'male_{id}',\n",
+    "        r'\\bboy\\b': f'male_{id}',\n",
+    "        r'\\bgentleman\\b': f'male_{id}',\n",
+    "        r'\\bguy\\b':f'male_{id}',\n",
+    "        r'\\bfemale\\b': f'female_{id}',\n",
+    "        r'\\bwoman\\b': f'female_{id}',\n",
+    "        r'\\bwomen\\b': f'female_{id}',\n",
+    "        r'\\bshe\\b': f'female_{id}',\n",
+    "        r'\\bgirl\\b': f'female_{id}',\n",
+    "        r'\\bgirls\\b': f'female_{id}',\n",
+    "        r'\\blady\\b': f'female_{id}',\n",
+    "        r'\\bcheerleader\\b': f'female_{id}',\n",
+    "        r'\\bperson\\b':f'person_{id}',\n",
+    "        r'\\bi\\b':f'person_{id}',\n",
+    "        r'\\byou\\b':f'person_{id}',\n",
+    "        r'\\bbaby\\b':f'person_{id}',\n",
+    "        r'\\bchild\\b':f'person_{id}',\n",
+    "        r'\\badult\\b':f'person_{id}',\n",
+    "        r'\\bpedestrian\\b':f'person_{id}',\n",
+    "        r'\\bunknown gender\\b':f'person_{id}',\n",
+    "        r'\\bunknown subject\\b':f'person_{id}',\n",
+    "        r'\\bwe\\b': f'people_{id}',\n",
+    "        r'\\bthey\\b': f'people_{id}',\n",
+    "        r'\\bpeople\\b': f'people_{id}'\n",
+    "    }\n",
+    "    for pattern, replacement in replacements.items():\n",
+    "        sentence = re.sub(pattern, replacement, sentence)\n",
+    "    return sentence"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 创建一个具有预定义列的空 DataFrame\n",
+    "columns = ['file_path', 'caption', 'id']\n",
+    "preprocess_df = pd.DataFrame(columns=columns)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 遍历 data 中的每一条记录\n",
+    "for index, row in train_json.iterrows():\n",
+    "    id = row['id']\n",
+    "    file_path = row['file_path']\n",
+    "    caption = row['captions']\n",
+    "\n",
+    "    # 确保 captions 是一个字符串并转换为小写\n",
+    "    if isinstance(caption, list):\n",
+    "        caption = ' '.join(caption).lower()\n",
+    "    else:\n",
+    "        caption = caption.lower()\n",
+    "\n",
+    "    # 替换人称代词和名词\n",
+    "    replaced_caption = replace_pronouns_and_nouns(caption, id)\n",
+    "\n",
+    "    # 提取 [人物_{id}] 和匹配 TOP_CLASSES_1 中的实体\n",
+    "    entities = []\n",
+    "\n",
+    "    # 提取所有替换后的人称代词和名词\n",
+    "    person_patterns = [\n",
+    "        re.compile(r'\\bmale_\\d+\\b'),\n",
+    "        re.compile(r'\\bfemale_\\d+\\b'),\n",
+    "        re.compile(r'\\bperson_\\d+\\b'),\n",
+    "        re.compile(r'\\bpeople_\\d+\\b')\n",
+    "    ]\n",
+    "    \n",
+    "    # 检查是否有替换后的人称代词或名词\n",
+    "    if not any(pattern.search(replaced_caption) for pattern in person_patterns):\n",
+    "        print(f\"No replacement in sentence: {id}\")\n",
+    "\n",
+    "    # 将结果添加到 preprocess_df 中\n",
+    "    new_row = pd.DataFrame({'file_path': [file_path], 'caption': [replaced_caption], 'id': [id]})\n",
+    "    preprocess_df = pd.concat([preprocess_df, new_row], ignore_index=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "                    file_path  \\\n",
+      "0          CUHK01/0107002.png   \n",
+      "1          CUHK01/0107004.png   \n",
+      "2          CUHK01/0107001.png   \n",
+      "3          CUHK01/0107003.png   \n",
+      "4  test_query/p5969_s7727.jpg   \n",
+      "\n",
+      "                                             caption     id  \n",
+      "0  the male_11004 has short, dark hair and wears ...  11004  \n",
+      "1  a male_11004 with a gray hoodie, book bag, and...  11004  \n",
+      "2  the male_11004 is wearing a grey hooded sweate...  11004  \n",
+      "3  male_11004 wearing a grey jacket, brown pants ...  11004  \n",
+      "4  the female_11005 is wearing a floral printed s...  11005  \n",
+      "the male_11004 has short, dark hair and wears khaki pants with an oversized grey hoodie. his black backpack hangs from one shoulder. a male_11004 wearing a gray, hooded jacket, a pair of wrinkled brown pants, a gray backpack and a pair of dark colored shoes.\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(preprocess_df.head())\n",
+    "print(preprocess_df.at[0, 'caption'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "nimages = preprocess_df.shape[0]\n",
+    "indexes = list(range(nimages))\n",
+    "random.shuffle(indexes)\n",
+    "\n",
+    "# pattern = os.path.join(base, f\"cuhkpedes-train-%06d.tar\")\n",
+    "pattern = os.path.join(base, f\"cuhkpedes-val-%06d.tar\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000000.tar 0 0.0 GB 0\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000001.tar 133 0.0 GB 133\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000002.tar 148 0.0 GB 281\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000003.tar 139 0.0 GB 420\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000004.tar 134 0.0 GB 554\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000005.tar 123 0.0 GB 677\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000006.tar 136 0.0 GB 813\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000007.tar 126 0.0 GB 939\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000008.tar 136 0.0 GB 1075\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000009.tar 151 0.0 GB 1226\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000010.tar 146 0.0 GB 1372\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000011.tar 143 0.0 GB 1515\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000012.tar 146 0.0 GB 1661\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000013.tar 127 0.0 GB 1788\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000014.tar 145 0.0 GB 1933\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000015.tar 135 0.0 GB 2068\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000016.tar 133 0.0 GB 2201\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000017.tar 121 0.0 GB 2322\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000018.tar 120 0.0 GB 2442\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000019.tar 128 0.0 GB 2570\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000020.tar 124 0.0 GB 2694\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000021.tar 115 0.0 GB 2809\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000022.tar 138 0.0 GB 2947\n",
+      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000023.tar 128 0.0 GB 3075\n"
+     ]
+    }
+   ],
+   "source": [
+    "def readfile(fname):\n",
+    "    \"Read a binary file from disk.\"\n",
+    "    with open(fname, \"rb\") as stream:\n",
+    "        return stream.read()\n",
+    "    \n",
+    "all_keys = set()\n",
+    "\n",
+    "with wds.ShardWriter(pattern, maxsize=1000000, maxcount=1000000) as sink:\n",
+    "    for i in indexes:\n",
+    "\n",
+    "        # Fields of one CUHK-PEDES record: the image path (relative to\n",
+    "        # image_path), the rewritten caption and the person id.\n",
+    "        fname = preprocess_df.at[i, 'file_path']\n",
+    "        caption = preprocess_df.at[i, 'caption']\n",
+    "        id = preprocess_df.at[i, 'id']\n",
+    "        fname = os.path.join(image_path, fname)\n",
+    "\n",
+    "        # Read the encoded image bytes as-is (the data has both .jpg and .png files).\n",
+    "        image = readfile(fname)\n",
+    "\n",
+    "        # Construct a unique key from the parent directory name and the file name.\n",
+    "        base_dir = os.path.dirname(fname)\n",
+    "        dir_name = os.path.basename(base_dir)\n",
+    "        key = os.path.splitext(os.path.basename(fname))[0]\n",
+    "        key = f\"{dir_name}_{key}\"\n",
+    "\n",
+    "        # Useful check.\n",
+    "        assert key not in all_keys, f\"Conflict detected: Key '{key}' already exists.\"\n",
+    "        all_keys.add(key)\n",
+    "\n",
+    "        # Construct the cls field with the new format.\n",
+    "        cls = f\"4 4 1\\n# male_{id} female_{id} person_{id} people_{id}\\n0 1 2 3\"        \n",
+    "\n",
+    "        # Construct a sample.\n",
+    "        xkey = key if True else \"%07d\" % i\n",
+    "        sample = {\"__key__\": xkey, \"jpg\": image, \"cls\": cls}\n",
+    "        # sample = {\"__key__\": xkey, \"jpg\": image, \"txt\": caption}\n",
+    "\n",
+    "        # Write the sample to the sharded tar archives.\n",
+    "        sink.write(sample)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "groupvit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

+ 33 - 0
imagenet/classify_val_images.py

@@ -0,0 +1,33 @@
+import os
+import shutil
+import torch
+
+# 设置根目录路径
+root_directory = "/mnt/vos-s9gjtkm2/reid/dataset/imagenet"
+val_directory = os.path.join(root_directory, "val")
+val_images_directory = val_directory
+
+# 读取验证集的 ground truth 文件
+ground_truth_file = os.path.join(root_directory, "ILSVRC2012_devkit_t12", "data", "ILSVRC2012_validation_ground_truth.txt")
+with open(ground_truth_file, "r") as f:
+    val_labels = f.readlines()
+
+# 读取 meta.bin 文件以获取类别信息
+meta_file = os.path.join(root_directory, "meta.bin")
+wnid_to_classes, val_wnids = torch.load(meta_file)
+
+# 创建类别目录
+for wnid in wnid_to_classes.keys():
+    class_directory = os.path.join(val_directory, wnid)
+    if not os.path.exists(class_directory):
+        os.makedirs(class_directory)
+
+# 将图片按类别分类
+for i, label in enumerate(val_labels):
+    label = label.strip()
+    wnid = val_wnids[i]
+    src_file = os.path.join(val_images_directory, f"ILSVRC2012_val_{i + 1:08d}.JPEG")
+    dest_file = os.path.join(val_directory, wnid, f"ILSVRC2012_val_{i + 1:08d}.JPEG")
+    shutil.move(src_file, dest_file)
+
+print("Validation images have been classified by category.")

+ 41 - 0
imagenet/extract_imagenet.py

@@ -0,0 +1,41 @@
+import os
+import tarfile
+
+TRAIN_SRC_DIR = '/mnt/vos-s9gjtkm2/reid/dataset/ImageNet/ILSVRC2012/ILSVRC2012_img_train.tar'
+TRAIN_DEST_DIR = '/mnt/vos-s9gjtkm2/reid/dataset/imagenet/train'
+VAL_SRC_DIR = '/mnt/vos-s9gjtkm2/reid/dataset/ImageNet/ILSVRC2012/ILSVRC2012_img_val.tar'
+VAL_DEST_DIR = '/mnt/vos-s9gjtkm2/reid/dataset/imagenet/val'
+
+
+def extract_train():
+    with open(TRAIN_SRC_DIR, 'rb') as f:
+        tar = tarfile.open(fileobj=f, mode='r:')
+        for i, item in enumerate(tar):
+            cls_name = item.name.strip(".tar")
+            a = tar.extractfile(item)
+            b = tarfile.open(fileobj=a, mode="r:")
+            e_path = "{}/{}/".format(TRAIN_DEST_DIR, cls_name)
+            if not os.path.isdir(e_path):
+                os.makedirs(e_path)
+            print("#", i, "extract train dateset to >>>", e_path)
+            b.extractall(e_path)
+            # names = b.getnames()
+            # for name in names:
+            #     b.extract(name, e_path)
+
+
+def extract_val():
+    with open(VAL_SRC_DIR, 'rb') as f:
+        tar = tarfile.open(fileobj=f, mode='r:')
+        if not os.path.isdir(VAL_DEST_DIR):
+            os.makedirs(VAL_DEST_DIR)
+        print("extract val dateset to >>>", VAL_DEST_DIR)
+        names = tar.getnames()
+        for name in names:
+            tar.extract(name, VAL_DEST_DIR)
+
+
+if __name__ == '__main__':
+    extract_train()
+    extract_val()
+