浏览代码

feat(dataset): 统计并保存 CUHK-PEDES 数据集中的 topk 类别

- 读取 CUHK-PEDES 数据集的 JSON 文件
- 使用 NLTK 进行分词和词性标注
- 统计名词出现频率,排除停用词和特定颜色词汇
- 获取出现次数最多的 topk 类别
- 将类别保存为 JSON 文件
Yijun Fu 1 月之前
父节点
当前提交
9b398b8313
共有 1 个文件被更改,包括 426 次插入和 0 次删除
  1. 426 0
      cuhkpedes/cuhkpedes_topk_summarize.ipynb

+ 426 - 0
cuhkpedes/cuhkpedes_topk_summarize.ipynb

@@ -0,0 +1,426 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 统计cuhkpedes中的topk个类别"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import json\n",
+    "import collections\n",
+    "import nltk\n",
+    "import string\n",
+    "import pandas as pd\n",
+    "from nltk.tokenize import *\n",
+    "from nltk.corpus import stopwords\n",
+    "from nltk import pos_tag\n",
+    "from concurrent.futures import ProcessPoolExecutor"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 确保已下载所需的 NLTK 数据（punkt、stopwords、averaged_perceptron_tagger）"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "punkt tokenizer models are already downloaded.\n",
+      "Stopwords are already downloaded.\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[nltk_data] Downloading package averaged_perceptron_tagger to\n",
+      "[nltk_data]     /mnt/vos-s9gjtkm2/reid/nltk_data...\n",
+      "[nltk_data]   Package averaged_perceptron_tagger is already up-to-\n",
+      "[nltk_data]       date!\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "True"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Probe punkt by tokenizing a sample sentence; a LookupError is treated as \"data missing\".\n",
+    "try:\n",
+    "    sent_tokenize(\"This is a test sentence.\")\n",
+    "except LookupError:\n",
+    "    print(\"punkt tokenizer models are not downloaded.\")\n",
+    "    nltk.download('punkt')\n",
+    "else:\n",
+    "    print(\"punkt tokenizer models are already downloaded.\")\n",
+    "\n",
+    "# Same probe for the English stop-word list.\n",
+    "try:\n",
+    "    stopwords.words('english')\n",
+    "except LookupError:\n",
+    "    print(\"Stopwords are not downloaded.\")\n",
+    "    nltk.download('stopwords')\n",
+    "else:\n",
+    "    print(\"Stopwords are already downloaded.\")\n",
+    "\n",
+    "# The POS tagger model is requested unconditionally; as the cell's last expression,\n",
+    "# the return value of nltk.download() is what the notebook displays.\n",
+    "nltk.download('averaged_perceptron_tagger')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 设置数据集路径"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Dataset root: defaults to ~/dataset/cross_reid/CUHK-PEDES; set the CUHK_PEDES_DIR\n",
+    "# environment variable to point the notebook at a different location.\n",
+    "home_directory = os.path.expanduser('~')\n",
+    "dataset_path = os.environ.get(\n",
+    "    'CUHK_PEDES_DIR',\n",
+    "    os.path.join(home_directory, 'dataset/cross_reid/CUHK-PEDES'),\n",
+    ")\n",
+    "# Input annotations and the output class list both live in the dataset root.\n",
+    "json_file = os.path.join(dataset_path, 'reid_raw.json')\n",
+    "output_json_file = os.path.join(dataset_path, 'class.json')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 读取 JSON 文件"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# JSON text is UTF-8; pass the encoding explicitly instead of relying on the locale default.\n",
+    "with open(json_file, 'r', encoding='utf-8') as raw_file:\n",
+    "    data = json.load(raw_file)\n",
+    "\n",
+    "data_df = pd.DataFrame(data)\n",
+    "# Fail fast: the later cells read data_df['captions'].\n",
+    "assert 'captions' in data_df.columns, \"expected a 'captions' field in the raw annotations\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 定义停用词和标点符号\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "stop_words = set(stopwords.words('english'))\n",
+    "punctuations = set(string.punctuation)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 定义要排除的颜色和其他非实体类词汇"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Nouns the tagger may emit that should not count as classes.\n",
+    "# Written as a set literal; the grouping below is only for readability.\n",
+    "exclude_words = {\n",
+    "    # colour and shade words\n",
+    "    'black', 'white', 'dark', 'light', 'red', 'blue', 'green', 'yellow', 'brown',\n",
+    "    'gray', 'grey', 'pink', 'purple', 'orange', 'tan', 'beige', 'colored', 'color',\n",
+    "    # position and size words\n",
+    "    'back', 'left', 'right', 'top', 'front', 'bottom', 'side', 'underneath',\n",
+    "    'small', 'long', 'longer', 'longest', 'length',\n",
+    "    # pattern, material and garment-detail words\n",
+    "    'stripes', 'stripe', 'print', 'pattern', 'design', 'trim', 'hem', 'style',\n",
+    "    'leather', 'cargo', 'outfit', 'soles',\n",
+    "    # everything else\n",
+    "    'wearing', 'also', 'something', 'picture', 'shopping', 'body', 'cell',\n",
+    "    'object', 'street', 'sidewalk', 'walks', 'walking',\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 初始化一个计数器来统计类别出现频率"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class_counter = collections.Counter()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 定义处理单个描述的函数"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def process_description(captions):\n",
+    "    \"\"\"Extract candidate class nouns from the captions of one sample.\n",
+    "\n",
+    "    Args:\n",
+    "        captions: iterable of caption strings; they are joined with single spaces\n",
+    "            and lower-cased before tokenization.\n",
+    "\n",
+    "    Returns:\n",
+    "        list of tokens whose POS tag starts with 'NN' and that appear in none of\n",
+    "        stop_words, punctuations or exclude_words. Duplicates are kept so that\n",
+    "        callers can count occurrences.\n",
+    "    \"\"\"\n",
+    "    text = ' '.join(captions).lower()\n",
+    "    tagged_tokens = pos_tag(word_tokenize(text))\n",
+    "    # Single pass: keep noun-tagged tokens that survive all three filters.\n",
+    "    return [\n",
+    "        token\n",
+    "        for token, tag in tagged_tokens\n",
+    "        if tag.startswith('NN')\n",
+    "        and not (token in stop_words or token in punctuations or token in exclude_words)\n",
+    "    ]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 并行处理与合并结果"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 使用多进程并行处理描述\n",
+    "# Tag every sample's captions in parallel worker processes. chunksize batches the\n",
+    "# per-sample tasks so inter-process overhead is not paid once per row.\n",
+    "# NOTE(review): the worker function is defined in the notebook itself; this presumably\n",
+    "# relies on a fork-based process start method - confirm before running on other platforms.\n",
+    "with ProcessPoolExecutor() as executor:\n",
+    "    results = list(executor.map(process_description, data_df['captions'], chunksize=256))\n",
+    "\n",
+    "# Rebuild the counter from scratch so re-running this cell does not double-count.\n",
+    "class_counter = collections.Counter(noun for nouns in results for noun in nouns)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 获取出现次数最多的 K 个类别"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "top_k = 100  # 你可以根据需要调整这个值\n",
+    "top_k_classes = class_counter.most_common(top_k)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 打印结果"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Top K classes in CUHK-PEDES descriptions:\n",
+      "shirt: 46183\n",
+      "shoes: 36592\n",
+      "man: 33758\n",
+      "pants: 31576\n",
+      "pair: 28255\n",
+      "woman: 26066\n",
+      "hair: 25314\n",
+      "shorts: 17841\n",
+      "bag: 15115\n",
+      "jeans: 11191\n",
+      "jacket: 10212\n",
+      "shoulder: 9120\n",
+      "backpack: 8838\n",
+      "hand: 6999\n",
+      "t-shirt: 6945\n",
+      "glasses: 6663\n",
+      "dress: 6484\n",
+      "sneakers: 6459\n",
+      "sandals: 4569\n",
+      "skirt: 4567\n",
+      "person: 4452\n",
+      "sleeve: 3968\n",
+      "tennis: 3860\n",
+      "purse: 3848\n",
+      "sleeves: 2945\n",
+      "lady: 2749\n",
+      "girl: 2678\n",
+      "pack: 2550\n",
+      "coat: 2337\n",
+      "sweater: 2318\n",
+      "arm: 2300\n",
+      "socks: 2096\n",
+      "boots: 2057\n",
+      "phone: 1879\n",
+      "male: 1759\n",
+      "hat: 1690\n",
+      "hands: 1625\n",
+      "shoulders: 1571\n",
+      "jean: 1554\n",
+      "blouse: 1451\n",
+      "tee: 1423\n",
+      "collar: 1412\n",
+      "watch: 1406\n",
+      "boy: 1362\n",
+      "head: 1301\n",
+      "ponytail: 1274\n",
+      "leggings: 1269\n",
+      "knee: 1195\n",
+      "belt: 1172\n",
+      "sweatshirt: 1159\n",
+      "denim: 1137\n",
+      "tank: 1135\n",
+      "neck: 1134\n",
+      "suit: 1120\n",
+      "button: 1115\n",
+      "feet: 1108\n",
+      "strap: 1065\n",
+      "polo: 1062\n",
+      "plaid: 1050\n",
+      "wrist: 1046\n",
+      "slacks: 1038\n",
+      "tie: 954\n",
+      "camera: 901\n",
+      "female: 900\n",
+      "umbrella: 888\n",
+      "heels: 881\n",
+      "khaki: 875\n",
+      "knees: 871\n",
+      "vest: 868\n",
+      "tail: 855\n",
+      "flops: 855\n",
+      "sleeveless: 833\n",
+      "pony: 823\n",
+      "hoodie: 817\n",
+      "cap: 816\n",
+      "straps: 799\n",
+      "book: 789\n",
+      "arms: 768\n",
+      "sunglasses: 766\n",
+      "waist: 751\n",
+      "handbag: 735\n",
+      "hood: 685\n",
+      "logo: 676\n",
+      "tote: 660\n",
+      "flip: 655\n",
+      "chest: 655\n",
+      "scarf: 654\n",
+      "bags: 579\n",
+      "leg: 572\n",
+      "pocket: 570\n",
+      "women: 535\n",
+      "adult: 508\n",
+      "face: 492\n",
+      "tights: 472\n",
+      "capri: 450\n",
+      "child: 437\n",
+      "tshirt: 436\n",
+      "gold: 425\n",
+      "paper: 423\n",
+      "messenger: 392\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Header line followed by one \"<class>: <count>\" line per entry of top_k_classes.\n",
+    "summary_lines = [f\"{cls}: {count}\" for cls, count in top_k_classes]\n",
+    "print(\"Top K classes in CUHK-PEDES descriptions:\", *summary_lines, sep=\"\\n\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 存储结果"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "类别已保存为 categories.json 文件\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Keep only the top-k class names, in the order of top_k_classes; class_counter.keys()\n",
+    "# would write every noun seen in the dataset.\n",
+    "categories = [cls for cls, _count in top_k_classes]\n",
+    "\n",
+    "# The handle is deliberately not named json_file: that name holds the input path,\n",
+    "# and rebinding it would break re-running the cell that reads the raw JSON.\n",
+    "with open(output_json_file, 'w', encoding='utf-8') as out_file:\n",
+    "    json.dump(categories, out_file)\n",
+    "\n",
+    "# Report the real destination instead of a hard-coded file name.\n",
+    "print(f\"类别已保存为 {output_json_file} 文件\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "ovsegmentor",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}