{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 统计cuhkpedes中的topk个类别" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import collections\n", "import nltk\n", "import string\n", "import pandas as pd\n", "from nltk.tokenize import *\n", "from nltk.corpus import stopwords\n", "from nltk import pos_tag\n", "from concurrent.futures import ProcessPoolExecutor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 确保你已经下载了 NLTK 的 stopwords 数据" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "punkt tokenizer models are already downloaded.\n", "Stopwords are already downloaded.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[nltk_data] Downloading package averaged_perceptron_tagger to\n", "[nltk_data] /mnt/vos-s9gjtkm2/reid/nltk_data...\n", "[nltk_data] Package averaged_perceptron_tagger is already up-to-\n", "[nltk_data] date!\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "try:\n", " sent_tokenize(\"This is a test sentence.\")\n", " print(\"punkt tokenizer models are already downloaded.\")\n", "except LookupError:\n", " print(\"punkt tokenizer models are not downloaded.\")\n", " nltk.download('punkt')\n", "\n", "try:\n", " stopwords.words('english')\n", " print(\"Stopwords are already downloaded.\")\n", "except LookupError:\n", " print(\"Stopwords are not downloaded.\")\n", " nltk.download('stopwords')\n", "\n", "nltk.download('averaged_perceptron_tagger')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 设置数据集路径" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# 获取主目录\n", "home_directory = os.path.expanduser('~')\n", "dataset_path = os.path.join(home_directory, 'dataset/cross_reid/CUHK-PEDES')\n", "json_file = os.path.join(dataset_path, 'reid_raw.json')\n", "output_json_file = os.path.join(dataset_path, 'class.json')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 读取 JSON 文件" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "with open(json_file, 'r') as file:\n", " data = json.load(file)\n", "\n", "data_df = pd.DataFrame(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 定义停用词和标点符号\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "stop_words = set(stopwords.words('english'))\n", "punctuations = set(string.punctuation)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 定义要排除的颜色和其他非实体类词汇" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "exclude_words = set([\n", " 'black', 'white', 'dark', 'red', 'blue', 'green', 'yellow', 'brown', 'gray', 'pink', 'purple', 'orange',\n", " 'wearing', 'long', 'grey', 'also', 'colored', 'back', 'left', 'right', 'small', 'top', 'front', 'bottom',\n", " 'long', 'longer', 'longest', 'length', 'side', 'light', 'stripes', 'something', 'tan', 'stripe', 'print',\n", " 'picture', 'shopping', 'body', 'design', 'cell', 'color', 'object', 'trim', 'pattern', 'street', 'underneath',\n", " 'soles', 'beige', 'sidewalk', 'cargo', 'leather', 'outfit', 'walks', 'hem', 'walking', 'style'\n", "])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 初始化一个计数器来统计类别出现频率" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "class_counter = collections.Counter()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 定义处理单个描述的函数" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "def process_description(captions):\n", " # 合并所有描述并转换为小写\n", " combined_description = ' '.join(captions).lower()\n", " # 分词\n", " words = word_tokenize(combined_description)\n", " # 词性标注\n", " tagged_words = pos_tag(words)\n", " # 提取名词\n", " nouns = [word for word, pos in tagged_words if pos.startswith('NN')]\n", " # 去除停用词、标点符号和排除词汇\n", " nouns = [word for word in nouns if word not in stop_words and word not in punctuations and word not in exclude_words]\n", " return nouns" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 并行处理与合并结果" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# 使用多进程并行处理描述\n", "with ProcessPoolExecutor() as executor:\n", " results = list(executor.map(process_description, data_df['captions']))\n", "\n", "# 合并结果并更新计数器\n", "for nouns in results:\n", " class_counter.update(nouns)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 了解K个类别" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "top_k = 100 # 你可以根据需要调整这个值\n", "top_k_classes = class_counter.most_common(top_k)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 打印结果" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Top K classes in CUHK-PEDES descriptions:\n", "shirt: 46183\n", "shoes: 36592\n", "man: 33758\n", "pants: 31576\n", "pair: 28255\n", "woman: 26066\n", "hair: 25314\n", "shorts: 17841\n", "bag: 15115\n", "jeans: 11191\n", "jacket: 10212\n", "shoulder: 9120\n", "backpack: 8838\n", "hand: 6999\n", "t-shirt: 6945\n", "glasses: 6663\n", "dress: 6484\n", "sneakers: 6459\n", "sandals: 4569\n", "skirt: 4567\n", "person: 4452\n", "sleeve: 3968\n", "tennis: 3860\n", "purse: 3848\n", "sleeves: 2945\n", "lady: 2749\n", "girl: 2678\n", "pack: 2550\n", "coat: 2337\n", "sweater: 2318\n", "arm: 2300\n", "socks: 2096\n", "boots: 2057\n", "phone: 1879\n", "male: 1759\n", "hat: 1690\n", "hands: 1625\n", "shoulders: 1571\n", "jean: 1554\n", "blouse: 1451\n", "tee: 1423\n", "collar: 1412\n", "watch: 1406\n", "boy: 1362\n", "head: 1301\n", "ponytail: 1274\n", "leggings: 1269\n", "knee: 1195\n", "belt: 1172\n", "sweatshirt: 1159\n", "denim: 1137\n", "tank: 1135\n", "neck: 1134\n", "suit: 1120\n", "button: 1115\n", "feet: 1108\n", "strap: 1065\n", "polo: 1062\n", "plaid: 1050\n", "wrist: 1046\n", "slacks: 1038\n", "tie: 954\n", "camera: 901\n", "female: 900\n", "umbrella: 888\n", "heels: 881\n", "khaki: 875\n", "knees: 871\n", "vest: 868\n", "tail: 855\n", "flops: 855\n", "sleeveless: 833\n", "pony: 823\n", "hoodie: 817\n", "cap: 816\n", "straps: 799\n", "book: 789\n", "arms: 768\n", "sunglasses: 766\n", "waist: 751\n", "handbag: 735\n", "hood: 685\n", "logo: 676\n", "tote: 660\n", "flip: 655\n", "chest: 655\n", "scarf: 654\n", "bags: 579\n", "leg: 572\n", "pocket: 570\n", "women: 535\n", "adult: 508\n", "face: 492\n", "tights: 472\n", "capri: 450\n", "child: 437\n", "tshirt: 436\n", "gold: 425\n", "paper: 423\n", "messenger: 392\n" ] } ], "source": [ "print(\"Top K classes in CUHK-PEDES descriptions:\")\n", "for cls, count in top_k_classes:\n", " print(f\"{cls}: {count}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 存储结果" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "类别已保存为 categories.json 文件\n" ] } ], "source": [ "# 提取类别\n", "categories = list(class_counter.keys())\n", "\n", "# 将类别保存为 JSON 文件\n", "with open(output_json_file, 'w') as json_file:\n", " json.dump(categories, json_file)\n", "\n", "print(\"类别已保存为 categories.json 文件\")" ] } ], "metadata": { "kernelspec": { "display_name": "ovsegmentor", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 2 }