|
@@ -0,0 +1,257 @@
|
|
|
|
+{
|
|
|
|
+ "cells": [
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# 添加cuhkpedes每句的实体"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 1,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "import os\n",
|
|
|
|
+ "import json\n",
|
|
|
|
+ "import collections\n",
|
|
|
|
+ "import string\n",
|
|
|
|
+ "import pandas as pd\n",
|
|
|
|
+ "import numpy as np\n",
|
|
|
|
+ "import nltk\n",
|
|
|
|
+ "from nltk.tokenize import *"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# 确保你已经下载了 NLTK 的 punkt 数据"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 2,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "punkt tokenizer models are already downloaded.\n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "try:\n",
|
|
|
|
+ " sent_tokenize(\"This is a test sentence.\")\n",
|
|
|
|
+ " print(\"punkt tokenizer models are already downloaded.\")\n",
|
|
|
|
+ "except LookupError:\n",
|
|
|
|
+ " print(\"punkt tokenizer models are not downloaded.\")\n",
|
|
|
|
+ " nltk.download('punkt')"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# 设置数据集路径"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "# 获取主目录\n",
|
|
|
|
+ "home_directory = os.path.expanduser('~')\n",
|
|
|
|
+ "dataset_path = os.path.join(home_directory, 'dataset/cross_reid/CUHK-PEDES')\n",
|
|
|
|
+ "class_file = os.path.join(dataset_path, 'class.json')\n",
|
|
|
|
+ "raw_file = os.path.join(dataset_path, 'reid_raw.json')"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# 读取 JSON 文件"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 4,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "with open(class_file, 'r') as file:\n",
|
|
|
|
+ " class_data = json.load(file)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "class_df = pd.DataFrame(class_data, columns=['class'])\n",
|
|
|
|
+ "class_df_set = set(class_df['class'])\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "with open(raw_file, 'r') as file:\n",
|
|
|
|
+ " raw_data = json.load(file)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "raw_df = pd.DataFrame(raw_data)"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# 添加实体\n"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 7,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ " split captions \\\n",
|
|
|
|
+ "0 train a pedestrian with dark hair is wearing red and... \n",
|
|
|
|
+ "1 train a man wearing a black jacket, black pants, red... \n",
|
|
|
|
+ "2 train the man is wearing a black jacket, green jeans... \n",
|
|
|
|
+ "3 train he's wearing a black hooded sweatshirt with a ... \n",
|
|
|
|
+ "4 train the man is walking. he is wearing a bright gr... \n",
|
|
|
|
+ "\n",
|
|
|
|
+ " file_path \\\n",
|
|
|
|
+ "0 CUHK01/0363004.png \n",
|
|
|
|
+ "1 CUHK01/0363003.png \n",
|
|
|
|
+ "2 CUHK01/0363001.png \n",
|
|
|
|
+ "3 CUHK01/0363002.png \n",
|
|
|
|
+ "4 train_query/p8130_s10935.jpg \n",
|
|
|
|
+ "\n",
|
|
|
|
+ " entity \n",
|
|
|
|
+ "0 hair,person,shoes,sneakers,sweatshirt,pedestri... \n",
|
|
|
|
+ "1 man,jacket,hand,shoes,sneakers,shirt,pants \n",
|
|
|
|
+ "2 man,jacket,jeans,carrying,backpack,sneakers,pants \n",
|
|
|
|
+ "3 man,hair,hoodie,carrying,backpack,shoes,sneake... \n",
|
|
|
|
+ "4 man,shoes,sleeved,vest,shirt,hat,pants \n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "def judge_noun(word):\n",
|
|
|
|
+ " if word in class_df_set:\n",
|
|
|
|
+ " return 1\n",
|
|
|
|
+ " return 0\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def add_entity(items):\n",
|
|
|
|
+ " # 合并所有描述并转换为小写\n",
|
|
|
|
+ " combined_description = ' '.join(items['captions']).lower()\n",
|
|
|
|
+ " items['captions'] = combined_description\n",
|
|
|
|
+ " # 分词\n",
|
|
|
|
+ " all_words = nltk.word_tokenize(combined_description)\n",
|
|
|
|
+ " # 包含的实体\n",
|
|
|
|
+ " valid_list = [judge_noun(word) for word in all_words]\n",
|
|
|
|
+ " valid = sum(valid_list)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " if valid:\n",
|
|
|
|
+ " valid_words = np.array(all_words)[np.argwhere(valid_list)][:,0].tolist()\n",
|
|
|
|
+ " valid_words = list(set(valid_words)) ## keep unique entities\n",
|
|
|
|
+ " items['entity'] = ','.join(valid_words)\n",
|
|
|
|
+ " \n",
|
|
|
|
+ " return items.filter(items=['split', 'captions', 'file_path', 'entity'])\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "entity_added_df = raw_df.apply(add_entity, axis=1)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "print(entity_added_df.head())"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# 保存train_entity_add.csv"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 17,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ " captions file_path \\\n",
|
|
|
|
+ "0 a pedestrian with dark hair is wearing red and... CUHK01/0363004.png \n",
|
|
|
|
+ "\n",
|
|
|
|
+ " entity \n",
|
|
|
|
+ "0 hair,person,shoes,sneakers,sweatshirt,pedestri... \n",
|
|
|
|
+ " captions \\\n",
|
|
|
|
+ "0 a man wearing a blue and white stripe tank top... \n",
|
|
|
|
+ "\n",
|
|
|
|
+ " file_path \\\n",
|
|
|
|
+ "0 train_query/p8848_s17661.jpg \n",
|
|
|
|
+ "\n",
|
|
|
|
+ " entity \n",
|
|
|
|
+ "0 man,striped,tank,neck,pair,shoes,around,headph... \n",
|
|
|
|
+ " captions file_path \\\n",
|
|
|
|
+ "0 the man has short, dark hair and wears khaki p... CUHK01/0107002.png \n",
|
|
|
|
+ "\n",
|
|
|
|
+ " entity \n",
|
|
|
|
+ "0 man,hair,khaki,hoodie,hangs,shoulder,jacket,ba... \n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "processed_train_data = entity_added_df.loc[entity_added_df['split'] == 'train']\n",
|
|
|
|
+ "processed_test_data = entity_added_df.loc[entity_added_df['split'] == 'test']\n",
|
|
|
|
+ "processed_val_data = entity_added_df.loc[entity_added_df['split'] == 'val']\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "del processed_train_data['split']\n",
|
|
|
|
+ "del processed_test_data['split']\n",
|
|
|
|
+ "del processed_val_data['split']\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# 重置索引并丢弃原始索引\n",
|
|
|
|
+ "processed_train_data = processed_train_data.reset_index(drop=True)\n",
|
|
|
|
+ "processed_test_data = processed_test_data.reset_index(drop=True)\n",
|
|
|
|
+ "processed_val_data = processed_val_data.reset_index(drop=True)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "print(processed_train_data.head(1))\n",
|
|
|
|
+ "print(processed_test_data.head(1))\n",
|
|
|
|
+ "print(processed_val_data.head(1))\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "processed_train_data.to_csv(f'{dataset_path}/train_entity.csv', index=False)\n",
|
|
|
|
+ "processed_test_data.to_csv(f'{dataset_path}/test_entity.csv', index=False)\n",
|
|
|
|
+ "processed_val_data.to_csv(f'{dataset_path}/val_entity.csv', index=False)"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": []
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "metadata": {
|
|
|
|
+ "kernelspec": {
|
|
|
|
+ "display_name": "ovsegmentor",
|
|
|
|
+ "language": "python",
|
|
|
|
+ "name": "python3"
|
|
|
|
+ },
|
|
|
|
+ "language_info": {
|
|
|
|
+ "codemirror_mode": {
|
|
|
|
+ "name": "ipython",
|
|
|
|
+ "version": 3
|
|
|
|
+ },
|
|
|
|
+ "file_extension": ".py",
|
|
|
|
+ "mimetype": "text/x-python",
|
|
|
|
+ "name": "python",
|
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
|
+ "version": "3.10.4"
|
|
|
|
+ }
|
|
|
|
+ },
|
|
|
|
+ "nbformat": 4,
|
|
|
|
+ "nbformat_minor": 2
|
|
|
|
+}
|