# 统计cuhkpedes中的topk个类别

In [15]:
import os
import json
import collections
import nltk
import string
import pandas as pd
from nltk.tokenize import *
from nltk.corpus import stopwords
from nltk import pos_tag
from concurrent.futures import ProcessPoolExecutor

# 确保你已经下载了 NLTK 的 stopwords 数据

In [16]:
try:
 sent_tokenize("This is a test sentence.")
 print("punkt tokenizer models are already downloaded.")
except LookupError:
 print("punkt tokenizer models are not downloaded.")
 nltk.download('punkt')

try:
 stopwords.words('english')
 print("Stopwords are already downloaded.")
except LookupError:
 print("Stopwords are not downloaded.")
 nltk.download('stopwords')

nltk.download('averaged_perceptron_tagger')

punkt tokenizer models are already downloaded.
Stopwords are already downloaded.


[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data] /mnt/vos-s9gjtkm2/reid/nltk_data...
[nltk_data] Package averaged_perceptron_tagger is already up-to-
[nltk_data] date!


True

# 设置数据集路径

In [17]:
# 获取主目录
home_directory = os.path.expanduser('~')
dataset_path = os.path.join(home_directory, 'dataset/cross_reid/CUHK-PEDES')
json_file = os.path.join(dataset_path, 'reid_raw.json')
output_json_file = os.path.join(dataset_path, 'class.json')

# 读取 JSON 文件

In [18]:
with open(json_file, 'r') as file:
 data = json.load(file)

data_df = pd.DataFrame(data)

# 定义停用词和标点符号


In [19]:
stop_words = set(stopwords.words('english'))
punctuations = set(string.punctuation)

# 定义要排除的颜色和其他非实体类词汇

In [20]:
exclude_words = set([
 'black', 'white', 'dark', 'red', 'blue', 'green', 'yellow', 'brown', 'gray', 'pink', 'purple', 'orange',
 'wearing', 'long', 'grey', 'also', 'colored', 'back', 'left', 'right', 'small', 'top', 'front', 'bottom',
 'long', 'longer', 'longest', 'length', 'side', 'light', 'stripes', 'something', 'tan', 'stripe', 'print',
 'picture', 'shopping', 'body', 'design', 'cell', 'color', 'object', 'trim', 'pattern', 'street', 'underneath',
 'soles', 'beige', 'sidewalk', 'cargo', 'leather', 'outfit', 'walks', 'hem', 'walking', 'style'
])

# 初始化一个计数器来统计类别出现频率

In [21]:
class_counter = collections.Counter()

# 定义处理单个描述的函数

In [22]:
def process_description(captions):
 # 合并所有描述并转换为小写
 combined_description = ' '.join(captions).lower()
 # 分词
 words = word_tokenize(combined_description)
 # 词性标注
 tagged_words = pos_tag(words)
 # 提取名词
 nouns = [word for word, pos in tagged_words if pos.startswith('NN')]
 # 去除停用词、标点符号和排除词汇
 nouns = [word for word in nouns if word not in stop_words and word not in punctuations and word not in exclude_words]
 return nouns

# 并行处理与合并结果

In [23]:
# 使用多进程并行处理描述
with ProcessPoolExecutor() as executor:
 results = list(executor.map(process_description, data_df['captions']))

# 合并结果并更新计数器
for nouns in results:
 class_counter.update(nouns)

# 了解K个类别

In [24]:
top_k = 100 # 你可以根据需要调整这个值
top_k_classes = class_counter.most_common(top_k)

# 打印结果

In [25]:
print("Top K classes in CUHK-PEDES descriptions:")
for cls, count in top_k_classes:
 print(f"{cls}: {count}")

Top K classes in CUHK-PEDES descriptions:
shirt: 46183
shoes: 36592
man: 33758
pants: 31576
pair: 28255
woman: 26066
hair: 25314
shorts: 17841
bag: 15115
jeans: 11191
jacket: 10212
shoulder: 9120
backpack: 8838
hand: 6999
t-shirt: 6945
glasses: 6663
dress: 6484
sneakers: 6459
sandals: 4569
skirt: 4567
person: 4452
sleeve: 3968
tennis: 3860
purse: 3848
sleeves: 2945
lady: 2749
girl: 2678
pack: 2550
coat: 2337
sweater: 2318
arm: 2300
socks: 2096
boots: 2057
phone: 1879
male: 1759
hat: 1690
hands: 1625
shoulders: 1571
jean: 1554
blouse: 1451
tee: 1423
collar: 1412
watch: 1406
boy: 1362
head: 1301
ponytail: 1274
leggings: 1269
knee: 1195
belt: 1172
sweatshirt: 1159
denim: 1137
tank: 1135
neck: 1134
suit: 1120
button: 1115
feet: 1108
strap: 1065
polo: 1062
plaid: 1050
wrist: 1046
slacks: 1038
tie: 954
camera: 901
female: 900
umbrella: 888
heels: 881
khaki: 875
knees: 871
vest: 868
tail: 855
flops: 855
sleeveless: 833
pony: 823
hoodie: 817
cap: 816
straps: 799
book: 789
arms: 768
sunglasses:

# 存储结果

In [26]:
# 提取类别
categories = list(class_counter.keys())

# 将类别保存为 JSON 文件
with open(output_json_file, 'w') as json_file:
 json.dump(categories, json_file)

print("类别已保存为 categories.json 文件")

类别已保存为 categories.json 文件
