Explorar o código

feat(cuhkpedes): 重构数据处理流程以支持 CUHK-PEDES 数据集

- 更新数据路径获取逻辑,使用用户主目录动态构建路径
- 优化数据加载和预处理,支持 train、test 和 val 三个数据集
- 重新定义数据写入逻辑,将图像和文本数据整合到单个 tar 文件中
- 改进键值生成策略,确保数据唯一性
Yijun Fu hai 1 mes
pai
achega
f4d38f0889
Modificáronse 1 ficheiros con 243 adicións e 186 borrados
  1. 243 186
      cuhkpedes/CUHK-PEDES2webdataset.ipynb

+ 243 - 186
cuhkpedes/CUHK-PEDES2webdataset.ipynb

@@ -9,7 +9,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/mnt/vos-s9gjtkm2/reid/miniconda3/envs/groupvit/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "/root/miniconda3/envs/groupvit/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
       "  from .autonotebook import tqdm as notebook_tqdm\n"
      ]
     }
@@ -32,13 +32,14 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "/mnt/vos-s9gjtkm2/reid/dataset/cross_reid\n"
+      "/root/dataset\n"
      ]
     }
    ],
    "source": [
-    "current_path = os.getcwd()\n",
-    "print(current_path)"
+    "home_dir = os.path.expanduser('~')\n",
+    "dataset_path = os.path.join(home_dir, 'dataset')\n",
+    "print(dataset_path)"
    ]
   },
   {
@@ -47,13 +48,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "CUHK_PEDES_path = os.path.join(current_path, 'CUHK-PEDES')\n",
+    "CUHK_PEDES_path = os.path.join(dataset_path, 'CUHK-PEDES')\n",
     "annotation_path = os.path.join(CUHK_PEDES_path, 'processed_data')\n",
     "image_path = os.path.join(CUHK_PEDES_path, 'imgs')\n",
     "train_json_path = os.path.join(annotation_path, 'train.json')\n",
     "val_json_path = os.path.join(annotation_path, 'val.json')\n",
     "test_json_path = os.path.join(annotation_path, 'test.json')\n",
-    "base = os.path.join(current_path, 'CUHK-PEDES_shards')"
+    "reid_raw_file = os.path.join(CUHK_PEDES_path, 'reid_raw.json')\n",
+    "base = os.path.join(dataset_path, 'CUHK-PEDES_shards')"
    ]
   },
   {
@@ -65,142 +67,185 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "  split                                           captions  \\\n",
-      "0   val  [The man has short, dark hair and wears khaki ...   \n",
-      "1   val  [A man with a gray hoodie, book bag, and khaki...   \n",
-      "2   val  [The man is wearing a grey hooded sweater, bro...   \n",
-      "3   val  [Man wearing a grey jacket, brown pants and bl...   \n",
-      "4   val  [The woman is wearing a floral printed shirt w...   \n",
+      "   split                                           captions  \\\n",
+      "0  train  [A pedestrian with dark hair is wearing red an...   \n",
+      "1  train  [A man wearing a black jacket, black pants, re...   \n",
+      "2  train  [The man is wearing a black jacket, green jean...   \n",
+      "3  train  [He's wearing a black hooded sweatshirt with a...   \n",
+      "4  train  [The man is walking.  He is wearing a bright g...   \n",
       "\n",
-      "                    file_path  \\\n",
-      "0          CUHK01/0107002.png   \n",
-      "1          CUHK01/0107004.png   \n",
-      "2          CUHK01/0107001.png   \n",
-      "3          CUHK01/0107003.png   \n",
-      "4  test_query/p5969_s7727.jpg   \n",
+      "                      file_path  \\\n",
+      "0            CUHK01/0363004.png   \n",
+      "1            CUHK01/0363003.png   \n",
+      "2            CUHK01/0363001.png   \n",
+      "3            CUHK01/0363002.png   \n",
+      "4  train_query/p8130_s10935.jpg   \n",
       "\n",
-      "                                    processed_tokens     id  \n",
-      "0  [[the, man, has, short, dark, hair, and, wears...  11004  \n",
-      "1  [[a, man, with, a, gray, hoodie, book, bag, an...  11004  \n",
-      "2  [[the, man, is, wearing, a, grey, hooded, swea...  11004  \n",
-      "3  [[man, wearing, a, grey, jacket, brown, pants,...  11004  \n",
-      "4  [[the, woman, is, wearing, a, floral, printed,...  11005  \n",
-      "3078\n"
+      "                                    processed_tokens  id  \n",
+      "0  [[a, pedestrian, with, dark, hair, is, wearing...   1  \n",
+      "1  [[a, man, wearing, a, black, jacket, black, pa...   1  \n",
+      "2  [[the, man, is, wearing, a, black, jacket, gre...   1  \n",
+      "3  [[hes, wearing, a, black, hooded, sweatshirt, ...   1  \n",
+      "4  [[the, man, is, walking, he, is, wearing, a, b...   2  \n",
+      "40206\n"
      ]
     }
    ],
    "source": [
-    "with open(val_json_path, 'r') as file:\n",
-    "    data = json.load(file)\n",
-    "\n",
-    "train_json = pd.DataFrame(data)\n",
-    "print(train_json.head())\n",
-    "print(train_json.shape[0])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# 定义一个函数来替换人称代词为对应的名词,并将名词替换为 {id}\n",
-    "def replace_pronouns_and_nouns(sentence, id):\n",
-    "    replacements = {\n",
-    "        r'\\bmale\\b': f'male_{id}',\n",
-    "        r'\\bman\\b': f'male_{id}',\n",
-    "        r'\\bmans\\b': f'male_{id}',\n",
-    "        r'\\bhe\\b': f'male_{id}',\n",
-    "        r'\\bboy\\b': f'male_{id}',\n",
-    "        r'\\bgentleman\\b': f'male_{id}',\n",
-    "        r'\\bguy\\b':f'male_{id}',\n",
-    "        r'\\bfemale\\b': f'female_{id}',\n",
-    "        r'\\bwoman\\b': f'female_{id}',\n",
-    "        r'\\bwomen\\b': f'female_{id}',\n",
-    "        r'\\bshe\\b': f'female_{id}',\n",
-    "        r'\\bgirl\\b': f'female_{id}',\n",
-    "        r'\\bgirls\\b': f'female_{id}',\n",
-    "        r'\\blady\\b': f'female_{id}',\n",
-    "        r'\\bcheerleader\\b': f'female_{id}',\n",
-    "        r'\\bperson\\b':f'person_{id}',\n",
-    "        r'\\bi\\b':f'person_{id}',\n",
-    "        r'\\byou\\b':f'person_{id}',\n",
-    "        r'\\bbaby\\b':f'person_{id}',\n",
-    "        r'\\bchild\\b':f'person_{id}',\n",
-    "        r'\\badult\\b':f'person_{id}',\n",
-    "        r'\\bpedestrian\\b':f'person_{id}',\n",
-    "        r'\\bunknown gender\\b':f'person_{id}',\n",
-    "        r'\\bunknown subject\\b':f'person_{id}',\n",
-    "        r'\\bwe\\b': f'people_{id}',\n",
-    "        r'\\bthey\\b': f'people_{id}',\n",
-    "        r'\\bpeople\\b': f'people_{id}'\n",
-    "    }\n",
-    "    for pattern, replacement in replacements.items():\n",
-    "        sentence = re.sub(pattern, replacement, sentence)\n",
-    "    return sentence"
+    "flag = \"None\"\n",
+    "if os.path.exists(val_json_path) & os.path.exists(train_json_path) & os.path.exists(test_json_path):\n",
+    "    with open(train_json_path, 'r') as file:\n",
+    "        train_json = json.load(file)\n",
+    "    with open(test_json_path, 'r') as file:\n",
+    "        test_json = json.load(file)\n",
+    "    with open(val_json_path, 'r') as file:\n",
+    "        val_json = json.load(file)\n",
+    "    train_data = pd.DataFrame(train_json)\n",
+    "    test_data = pd.DataFrame(test_json)\n",
+    "    val_data = pd.DataFrame(val_json)\n",
+    "    print(train_data.head())\n",
+    "    print(train_data.shape[0])\n",
+    "    flag = \"ttv\"\n",
+    "elif os.path.exists(reid_raw_file):\n",
+    "    with open(reid_raw_file, 'r') as file:\n",
+    "        reid_json = json.load(file)\n",
+    "    reid_data = pd.DataFrame(reid_json)\n",
+    "    print(reid_data.head())\n",
+    "    print(reid_data.shape[0])\n",
+    "    flag = \"raw\"\n",
+    "else: raise FileNotFoundError"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
     "# 创建一个具有预定义列的空 DataFrame\n",
-    "columns = ['file_path', 'caption', 'id']\n",
-    "preprocess_df = pd.DataFrame(columns=columns)"
+    "columns = ['file_path', 'captions', 'id']\n",
+    "processed_train_data = pd.DataFrame(columns=columns)\n",
+    "processed_test_data = pd.DataFrame(columns=columns)\n",
+    "processed_val_data = pd.DataFrame(columns=columns)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
-    "# 遍历 data 中的每一条记录\n",
-    "for index, row in train_json.iterrows():\n",
-    "    id = row['id']\n",
-    "    file_path = row['file_path']\n",
-    "    caption = row['captions']\n",
-    "\n",
-    "    # 确保 captions 是一个字符串并转换为小写\n",
-    "    if isinstance(caption, list):\n",
-    "        caption = ' '.join(caption).lower()\n",
-    "    else:\n",
-    "        caption = caption.lower()\n",
+    "# 遍历数据集并更新 processed_data\n",
+    "if flag == 'ttv':\n",
+    "    processed_train_data = train_data[['file_path', 'captions', 'id']]\n",
+    "    # for index, row in train_data.iterrows():\n",
+    "    #     id = row['id']\n",
+    "    #     file_path = row['file_path']\n",
+    "    #     captions = row['captions']\n",
     "\n",
-    "    # 替换人称代词和名词\n",
-    "    replaced_caption = replace_pronouns_and_nouns(caption, id)\n",
+    "    #     # 确保 captions 是一个字符串并转换为小写\n",
+    "    #     if isinstance(captions, list):\n",
+    "    #         captions = ' '.join(captions).lower()\n",
+    "    #     else:\n",
+    "    #         captions = captions.lower()\n",
+    "    \n",
+    "    #     # 将结果添加到 processed_data 中\n",
+    "    #     new_row = pd.DataFrame({'file_path': [file_path], 'captions': [captions], 'id': [id]})\n",
+    "    #     processed_train_data = pd.concat([processed_train_data, new_row], ignore_index=True)\n",
+    "    processed_test_data = test_data[['file_path', 'captions', 'id']]\n",
+    "    # for index, row in test_data.iterrows():\n",
+    "    #     id = row['id']\n",
+    "    #     file_path = row['file_path']\n",
+    "    #     captions = row['captions']\n",
     "\n",
-    "    # 提取 [人物_{id}] 和匹配 TOP_CLASSES_1 中的实体\n",
-    "    entities = []\n",
+    "    #     # 确保 captions 是一个字符串并转换为小写\n",
+    "    #     if isinstance(captions, list):\n",
+    "    #         captions = ' '.join(captions).lower()\n",
+    "    #     else:\n",
+    "    #         captions = captions.lower()\n",
+    "    \n",
+    "    #     # 将结果添加到 processed_data 中\n",
+    "    #     new_row = pd.DataFrame({'file_path': [file_path], 'captions': [captions], 'id': [id]})\n",
+    "    #     processed_test_data = pd.concat([processed_test_data, new_row], ignore_index=True)\n",
+    "    processed_val_data = val_data[['file_path', 'captions', 'id']]\n",
+    "    # for index, row in val_data.iterrows():\n",
+    "    #     id = row['id']\n",
+    "    #     file_path = row['file_path']\n",
+    "    #     captions = row['captions']\n",
     "\n",
-    "    # 提取所有替换后的人称代词和名词\n",
-    "    person_patterns = [\n",
-    "        re.compile(r'\\bmale_\\d+\\b'),\n",
-    "        re.compile(r'\\bfemale_\\d+\\b'),\n",
-    "        re.compile(r'\\bperson_\\d+\\b'),\n",
-    "        re.compile(r'\\bpeople_\\d+\\b')\n",
-    "    ]\n",
+    "    #     # 确保 captions 是一个字符串并转换为小写\n",
+    "    #     if isinstance(captions, list):\n",
+    "    #         captions = ' '.join(captions).lower()\n",
+    "    #     else:\n",
+    "    #         captions = captions.lower()\n",
     "    \n",
-    "    # 检查是否有替换后的人称代词或名词\n",
-    "    if not any(pattern.search(replaced_caption) for pattern in person_patterns):\n",
-    "        print(f\"No replacement in sentence: {id}\")\n",
+    "    #     # 将结果添加到 processed_data 中\n",
+    "    #     new_row = pd.DataFrame({'file_path': [file_path], 'captions': [captions], 'id': [id]})\n",
+    "    #     processed_val_data = pd.concat([processed_val_data, new_row], ignore_index=True)\n",
+    "        \n",
+    "elif flag == 'raw':\n",
+    "    processed_train_data = reid_data.loc[reid_data['split'] == 'train']\n",
+    "    processed_test_data = reid_data.loc[reid_data['split'] == 'test']\n",
+    "    processed_val_data = reid_data.loc[reid_data['split'] == 'val']\n",
+    "    processed_train_data = processed_train_data[['file_path', 'captions', 'id']]\n",
+    "    processed_test_data = processed_test_data[['file_path', 'captions', 'id']]\n",
+    "    processed_val_data = processed_val_data[['file_path', 'captions', 'id']]\n",
+    "    # for index, row in reid_data.iterrows():\n",
+    "    #     id = row['id']\n",
+    "    #     file_path = row['file_path']\n",
+    "    #     captions = row['captions']\n",
     "\n",
-    "    # 将结果添加到 preprocess_df 中\n",
-    "    new_row = pd.DataFrame({'file_path': [file_path], 'caption': [replaced_caption], 'id': [id]})\n",
-    "    preprocess_df = pd.concat([preprocess_df, new_row], ignore_index=True)"
+    "    #     # 确保 captions 是一个字符串并转换为小写\n",
+    "    #     if isinstance(captions, list):\n",
+    "    #         captions = ' '.join(captions).lower()\n",
+    "    #     else:\n",
+    "    #         captions = captions.lower()\n",
+    "        \n",
+    "    #     new_row = pd.DataFrame({'file_path': [file_path], 'captions': [captions], 'id': [id]})\n",
+    "    #     # 将结果添加到processed_data 中\n",
+    "    #     if row['split'] == 'train':\n",
+    "    #         processed_train_data = pd.concat([processed_train_data, new_row], ignore_index=True)\n",
+    "    #     elif row['split'] == 'test':\n",
+    "    #         processed_test_data = pd.concat([processed_test_data, new_row], ignore_index=True)\n",
+    "    #     elif row['split'] == 'val':\n",
+    "    #         processed_val_data = pd.concat([processed_val_data, new_row], ignore_index=True)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
+      "                      file_path  \\\n",
+      "0            CUHK01/0363004.png   \n",
+      "1            CUHK01/0363003.png   \n",
+      "2            CUHK01/0363001.png   \n",
+      "3            CUHK01/0363002.png   \n",
+      "4  train_query/p8130_s10935.jpg   \n",
+      "\n",
+      "                                            captions  id  \n",
+      "0  a pedestrian with dark hair is wearing red and...   1  \n",
+      "1  a man wearing a black jacket, black pants, red...   1  \n",
+      "2  the man is wearing a black jacket, green jeans...   1  \n",
+      "3  he's wearing a black hooded sweatshirt with a ...   1  \n",
+      "4  the man is walking.  he is wearing a bright gr...   2  \n",
+      "                      file_path  \\\n",
+      "0  train_query/p8848_s17661.jpg   \n",
+      "1  train_query/p8848_s17662.jpg   \n",
+      "2  train_query/p8848_s17663.jpg   \n",
+      "3   train_query/p4327_s5502.jpg   \n",
+      "4   train_query/p4327_s5503.jpg   \n",
+      "\n",
+      "                                            captions     id  \n",
+      "0  a man wearing a blue and white stripe tank top...  12004  \n",
+      "1  a man wearing a white and gray stripe shirt, a...  12004  \n",
+      "2  the man is wearing green pants and a green and...  12004  \n",
+      "3  a person is carrying a black shoulder bag over...  12005  \n",
+      "4  young man with dark hair and glasses, dark and...  12005  \n",
       "                    file_path  \\\n",
       "0          CUHK01/0107002.png   \n",
       "1          CUHK01/0107004.png   \n",
@@ -208,125 +253,137 @@
       "3          CUHK01/0107003.png   \n",
       "4  test_query/p5969_s7727.jpg   \n",
       "\n",
-      "                                             caption     id  \n",
-      "0  the male_11004 has short, dark hair and wears ...  11004  \n",
-      "1  a male_11004 with a gray hoodie, book bag, and...  11004  \n",
-      "2  the male_11004 is wearing a grey hooded sweate...  11004  \n",
-      "3  male_11004 wearing a grey jacket, brown pants ...  11004  \n",
-      "4  the female_11005 is wearing a floral printed s...  11005  \n",
-      "the male_11004 has short, dark hair and wears khaki pants with an oversized grey hoodie. his black backpack hangs from one shoulder. a male_11004 wearing a gray, hooded jacket, a pair of wrinkled brown pants, a gray backpack and a pair of dark colored shoes.\n"
+      "                                            captions     id  \n",
+      "0  the man has short, dark hair and wears khaki p...  11004  \n",
+      "1  a man with a gray hoodie, book bag, and khaki ...  11004  \n",
+      "2  the man is wearing a grey hooded sweater, brow...  11004  \n",
+      "3  man wearing a grey jacket, brown pants and bla...  11004  \n",
+      "4  the woman is wearing a floral printed shirt wi...  11005  \n"
      ]
     }
    ],
    "source": [
-    "print(preprocess_df.head())\n",
-    "print(preprocess_df.at[0, 'caption'])"
+    "# 定义一个函数,将列表中的字符串转换为小写并合并\n",
+    "def process_captions(captions_list):\n",
+    "    if isinstance(captions_list, list):\n",
+    "        return ' '.join([caption.lower() for caption in captions_list])\n",
+    "    else:\n",
+    "        return captions_list.lower()\n",
+    "    \n",
+    "processed_train_data['captions'] = processed_train_data['captions'].apply(process_captions)\n",
+    "processed_test_data['captions'] = processed_test_data['captions'].apply(process_captions)\n",
+    "processed_val_data['captions'] = processed_val_data['captions'].apply(process_captions)\n",
+    "\n",
+    "processed_train_data.reset_index(drop=True, inplace=True)\n",
+    "processed_test_data.reset_index(drop=True, inplace=True)\n",
+    "processed_val_data.reset_index(drop=True, inplace=True)\n",
+    "\n",
+    "print(processed_train_data.head())\n",
+    "print(processed_test_data.head())\n",
+    "print(processed_val_data.head())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_images = processed_train_data.shape[0]\n",
+    "train_indexes = list(range(train_images))\n",
+    "random.shuffle(train_indexes)\n",
+    "\n",
+    "test_images = processed_test_data.shape[0]\n",
+    "test_indexes = list(range(test_images))\n",
+    "random.shuffle(test_indexes)\n",
+    "\n",
+    "val_images = processed_val_data.shape[0]\n",
+    "val_indexes = list(range(val_images))\n",
+    "random.shuffle(val_indexes)\n",
+    "\n",
+    "train_pattern = os.path.join(base, f\"cuhkpedes-train-%06d.tar\")\n",
+    "test_pattern = os.path.join(base, f\"cuhkpedes-test-%06d.tar\")\n",
+    "val_pattern = os.path.join(base, f\"cuhkpedes-val-%06d.tar\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 36,
    "metadata": {},
    "outputs": [],
    "source": [
-    "nimages = preprocess_df.shape[0]\n",
-    "indexes = list(range(nimages))\n",
-    "random.shuffle(indexes)\n",
+    "def readfile(fname):\n",
+    "    \"Read a binary file from disk.\"\n",
+    "    with open(fname, \"rb\") as stream:\n",
+    "        return stream.read()\n",
     "\n",
-    "# pattern = os.path.join(base, f\"cuhkpedes-train-%06d.tar\")\n",
-    "pattern = os.path.join(base, f\"cuhkpedes-val-%06d.tar\")"
+    "train_keys = set()\n",
+    "test_keys = set()\n",
+    "val_keys = set()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 37,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000000.tar 0 0.0 GB 0\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000001.tar 133 0.0 GB 133\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000002.tar 148 0.0 GB 281\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000003.tar 139 0.0 GB 420\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000004.tar 134 0.0 GB 554\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000005.tar 123 0.0 GB 677\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000006.tar 136 0.0 GB 813\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000007.tar 126 0.0 GB 939\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000008.tar 136 0.0 GB 1075\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000009.tar 151 0.0 GB 1226\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000010.tar 146 0.0 GB 1372\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000011.tar 143 0.0 GB 1515\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000012.tar 146 0.0 GB 1661\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000013.tar 127 0.0 GB 1788\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000014.tar 145 0.0 GB 1933\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000015.tar 135 0.0 GB 2068\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000016.tar 133 0.0 GB 2201\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000017.tar 121 0.0 GB 2322\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000018.tar 120 0.0 GB 2442\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000019.tar 128 0.0 GB 2570\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000020.tar 124 0.0 GB 2694\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000021.tar 115 0.0 GB 2809\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000022.tar 138 0.0 GB 2947\n",
-      "# writing /mnt/vos-s9gjtkm2/reid/dataset/cross_reid/CUHK-PEDES_shards/cuhkpedes-val-000023.tar 128 0.0 GB 3075\n"
+      "# writing /root/dataset/CUHK-PEDES_shards/cuhkpedes-train-000000.tar 0 0.0 GB 0\n",
+      "# writing /root/dataset/CUHK-PEDES_shards/cuhkpedes-train-000001.tar 8060 0.1 GB 8060\n",
+      "# writing /root/dataset/CUHK-PEDES_shards/cuhkpedes-train-000002.tar 8010 0.1 GB 16070\n",
+      "# writing /root/dataset/CUHK-PEDES_shards/cuhkpedes-train-000003.tar 7924 0.1 GB 23994\n",
+      "# writing /root/dataset/CUHK-PEDES_shards/cuhkpedes-train-000004.tar 7933 0.1 GB 31927\n",
+      "# writing /root/dataset/CUHK-PEDES_shards/cuhkpedes-test-000000.tar 0 0.0 GB 0\n",
+      "# writing /root/dataset/CUHK-PEDES_shards/cuhkpedes-val-000000.tar 0 0.0 GB 0\n"
      ]
     }
    ],
    "source": [
-    "def readfile(fname):\n",
-    "    \"Read a binary file from disk.\"\n",
-    "    with open(fname, \"rb\") as stream:\n",
-    "        return stream.read()\n",
+    "def write_to_tar(processed_data, image_path, indexes, all_keys, pattern, maxcount=10000, maxsize=6e7):\n",
     "    \n",
-    "all_keys = set()\n",
+    "    output_dir = os.path.dirname(pattern)\n",
+    "    os.makedirs(output_dir, exist_ok=True)\n",
+    "    with wds.ShardWriter(pattern, maxcount, maxsize) as sink:\n",
+    "        for i in indexes:\n",
+    "            # instance: the file name and the numerical class.\n",
+    "            fname = processed_data.at[i, 'file_path']\n",
+    "            captions = processed_data.at[i, 'captions']\n",
+    "            id = processed_data.at[i, 'id']\n",
+    "            fname = os.path.join(image_path, fname)\n",
     "\n",
-    "with wds.ShardWriter(pattern, maxsize=1000000, maxcount=1000000) as sink:\n",
-    "    for i in indexes:\n",
+    "            # Read the JPEG-compressed image file contents.\n",
+    "            image = readfile(fname)\n",
     "\n",
-    "        # Internal information from the ImageNet dataset\n",
-    "        # instance: the file name and the numerical class.\n",
-    "        fname = preprocess_df.at[i, 'file_path']\n",
-    "        caption = preprocess_df.at[i, 'caption']\n",
-    "        id = preprocess_df.at[i, 'id']\n",
-    "        fname = os.path.join(image_path, fname)\n",
+    "            # Construct a uniqu keye from the filename.\n",
+    "            # base_dir = os.path.dirname(fname)\n",
+    "            # dir_name = os.path.basename(base_dir)\n",
+    "            # key = os.path.splitext(os.path.basename(fname))[0]\n",
+    "            key = f\"{id}_{i}\"\n",
     "\n",
-    "        # Read the JPEG-compressed image file contents.\n",
-    "        image = readfile(fname)\n",
+    "            # Useful check.\n",
+    "            assert key not in all_keys, f\"Conflict detected: Key '{key}' already exists.\"\n",
+    "            all_keys.add(key)   \n",
     "\n",
-    "        # Construct a uniqu keye from the filename.\n",
-    "        base_dir = os.path.dirname(fname)\n",
-    "        dir_name = os.path.basename(base_dir)\n",
-    "        key = os.path.splitext(os.path.basename(fname))[0]\n",
-    "        key = f\"{dir_name}_{key}\"\n",
+    "            # Construct a sample.\n",
+    "            xkey = key if True else \"%07d\" % i\n",
+    "            sample = {\"__key__\": xkey, \"jpg\": image, \"txt\": captions}\n",
     "\n",
-    "        # Useful check.\n",
-    "        assert key not in all_keys, f\"Conflict detected: Key '{key}' already exists.\"\n",
-    "        all_keys.add(key)\n",
+    "            # Write the sample to the sharded tar archives.\n",
+    "            sink.write(sample)\n",
     "\n",
-    "        # Construct the cls field with the new format.\n",
-    "        cls = f\"4 4 1\\n# male_{id} female_{id} person_{id} people_{id}\\n0 1 2 3\"        \n",
     "\n",
-    "        # Construct a sample.\n",
-    "        xkey = key if True else \"%07d\" % i\n",
-    "        sample = {\"__key__\": xkey, \"jpg\": image, \"cls\": cls}\n",
-    "        # sample = {\"__key__\": xkey, \"jpg\": image, \"txt\": caption}\n",
-    "\n",
-    "        # Write the sample to the sharded tar archives.\n",
-    "        sink.write(sample)"
+    "write_to_tar(processed_train_data, image_path, train_indexes, train_keys, train_pattern)\n",
+    "write_to_tar(processed_test_data, image_path, test_indexes, test_keys, test_pattern)\n",
+    "write_to_tar(processed_val_data, image_path, val_indexes, val_keys, val_pattern)"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "groupvit",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },