Selaa lähdekoodia

feat(cuhkpedes): 添加数据集实体标注

- 新增 cuhkpedes_entity_add.ipynb 文件,用于为 CUHK-PEDES 数据集添加实体标注
- 在 cuhkpedes_topk_summarize.ipynb 中更新了执行计数和部分代码内容
Yijun Fu 1 kuukausi sitten
vanhempi
sitoutus
1c9aafb325
2 muutettua tiedostoa jossa 272 lisäystä ja 14 poistoa
  1. 257 0
      cuhkpedes/cuhkpedes_entity_add.ipynb
  2. 15 14
      cuhkpedes/cuhkpedes_topk_summarize.ipynb

+ 257 - 0
cuhkpedes/cuhkpedes_entity_add.ipynb

@@ -0,0 +1,257 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 添加cuhkpedes每句的实体"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import json\n",
+    "import collections\n",
+    "import string\n",
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "import nltk\n",
+    "from nltk.tokenize import *"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 确保你已经下载了 NLTK 的 punkt 数据"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "punkt tokenizer models are already downloaded.\n"
+     ]
+    }
+   ],
+   "source": [
+    "try:\n",
+    "    sent_tokenize(\"This is a test sentence.\")\n",
+    "    print(\"punkt tokenizer models are already downloaded.\")\n",
+    "except LookupError:\n",
+    "    print(\"punkt tokenizer models are not downloaded.\")\n",
+    "    nltk.download('punkt')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 设置数据集路径"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 获取主目录\n",
+    "home_directory = os.path.expanduser('~')\n",
+    "dataset_path = os.path.join(home_directory, 'dataset/cross_reid/CUHK-PEDES')\n",
+    "class_file = os.path.join(dataset_path, 'class.json')\n",
+    "raw_file = os.path.join(dataset_path, 'reid_raw.json')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 读取 JSON 文件"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(class_file, 'r') as file:\n",
+    "    class_data = json.load(file)\n",
+    "\n",
+    "class_df = pd.DataFrame(class_data, columns=['class'])\n",
+    "class_df_set = set(class_df['class'])\n",
+    "\n",
+    "with open(raw_file, 'r') as file:\n",
+    "    raw_data = json.load(file)\n",
+    "\n",
+    "raw_df = pd.DataFrame(raw_data)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 添加实体\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "   split                                           captions  \\\n",
+      "0  train  a pedestrian with dark hair is wearing red and...   \n",
+      "1  train  a man wearing a black jacket, black pants, red...   \n",
+      "2  train  the man is wearing a black jacket, green jeans...   \n",
+      "3  train  he's wearing a black hooded sweatshirt with a ...   \n",
+      "4  train  the man is walking.  he is wearing a bright gr...   \n",
+      "\n",
+      "                      file_path  \\\n",
+      "0            CUHK01/0363004.png   \n",
+      "1            CUHK01/0363003.png   \n",
+      "2            CUHK01/0363001.png   \n",
+      "3            CUHK01/0363002.png   \n",
+      "4  train_query/p8130_s10935.jpg   \n",
+      "\n",
+      "                                              entity  \n",
+      "0  hair,person,shoes,sneakers,sweatshirt,pedestri...  \n",
+      "1         man,jacket,hand,shoes,sneakers,shirt,pants  \n",
+      "2  man,jacket,jeans,carrying,backpack,sneakers,pants  \n",
+      "3  man,hair,hoodie,carrying,backpack,shoes,sneake...  \n",
+      "4             man,shoes,sleeved,vest,shirt,hat,pants  \n"
+     ]
+    }
+   ],
+   "source": [
+    "def judge_noun(word):\n",
+    "    if word in class_df_set:\n",
+    "        return 1\n",
+    "    return 0\n",
+    "\n",
+    "def add_entity(items):\n",
+    "    # 合并所有描述并转换为小写\n",
+    "    combined_description = ' '.join(items['captions']).lower()\n",
+    "    items['captions'] = combined_description\n",
+    "    # 分词\n",
+    "    all_words = nltk.word_tokenize(combined_description)\n",
+    "    # 包含的实体\n",
+    "    valid_list = [judge_noun(word) for word in all_words]\n",
+    "    valid = sum(valid_list)\n",
+    "\n",
+    "    if valid:\n",
+    "            valid_words = np.array(all_words)[np.argwhere(valid_list)][:,0].tolist()\n",
+    "            valid_words = list(set(valid_words)) ## keep unique entities\n",
+    "            items['entity'] = ','.join(valid_words)\n",
+    "    \n",
+    "    return items.filter(items=['split', 'captions', 'file_path', 'entity'])\n",
+    "\n",
+    "\n",
+    "entity_added_df = raw_df.apply(add_entity, axis=1)\n",
+    "\n",
+    "print(entity_added_df.head())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 保存train_entity_add.csv"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "                                            captions           file_path  \\\n",
+      "0  a pedestrian with dark hair is wearing red and...  CUHK01/0363004.png   \n",
+      "\n",
+      "                                              entity  \n",
+      "0  hair,person,shoes,sneakers,sweatshirt,pedestri...  \n",
+      "                                            captions  \\\n",
+      "0  a man wearing a blue and white stripe tank top...   \n",
+      "\n",
+      "                      file_path  \\\n",
+      "0  train_query/p8848_s17661.jpg   \n",
+      "\n",
+      "                                              entity  \n",
+      "0  man,striped,tank,neck,pair,shoes,around,headph...  \n",
+      "                                            captions           file_path  \\\n",
+      "0  the man has short, dark hair and wears khaki p...  CUHK01/0107002.png   \n",
+      "\n",
+      "                                              entity  \n",
+      "0  man,hair,khaki,hoodie,hangs,shoulder,jacket,ba...  \n"
+     ]
+    }
+   ],
+   "source": [
+    "processed_train_data = entity_added_df.loc[entity_added_df['split'] == 'train']\n",
+    "processed_test_data = entity_added_df.loc[entity_added_df['split'] == 'test']\n",
+    "processed_val_data = entity_added_df.loc[entity_added_df['split'] == 'val']\n",
+    "\n",
+    "del processed_train_data['split']\n",
+    "del processed_test_data['split']\n",
+    "del processed_val_data['split']\n",
+    "\n",
+    "# 重置索引并丢弃原始索引\n",
+    "processed_train_data = processed_train_data.reset_index(drop=True)\n",
+    "processed_test_data = processed_test_data.reset_index(drop=True)\n",
+    "processed_val_data = processed_val_data.reset_index(drop=True)\n",
+    "\n",
+    "print(processed_train_data.head(1))\n",
+    "print(processed_test_data.head(1))\n",
+    "print(processed_val_data.head(1))\n",
+    "\n",
+    "processed_train_data.to_csv(f'{dataset_path}/train_entity.csv', index=False)\n",
+    "processed_test_data.to_csv(f'{dataset_path}/test_entity.csv', index=False)\n",
+    "processed_val_data.to_csv(f'{dataset_path}/val_entity.csv', index=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "ovsegmentor",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

+ 15 - 14
cuhkpedes/cuhkpedes_topk_summarize.ipynb

@@ -9,7 +9,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -34,7 +34,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
@@ -61,7 +61,7 @@
        "True"
       ]
      },
-     "execution_count": 16,
+     "execution_count": 2,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -93,7 +93,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -113,7 +113,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -132,7 +132,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -149,7 +149,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -158,7 +158,8 @@
     "    'wearing', 'long', 'grey', 'also', 'colored', 'back', 'left', 'right', 'small', 'top', 'front', 'bottom',\n",
     "    'long', 'longer', 'longest', 'length', 'side', 'light', 'stripes', 'something', 'tan', 'stripe', 'print',\n",
     "    'picture', 'shopping', 'body', 'design', 'cell', 'color', 'object', 'trim', 'pattern', 'street', 'underneath',\n",
-    "    'soles', 'beige', 'sidewalk', 'cargo', 'leather', 'outfit', 'walks', 'hem', 'walking', 'style'\n",
+    "    'soles', 'beige', 'sidewalk', 'cargo', 'leather', 'outfit', 'walks', 'hem', 'walking', 'style', 'inside',\n",
+    "    'wears', 'item', 'holding', 'carring', 'bright', 'short'\n",
     "])"
    ]
   },
@@ -171,7 +172,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -187,7 +188,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -214,7 +215,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -236,7 +237,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -253,7 +254,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
@@ -379,7 +380,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [
     {