|
@@ -0,0 +1,97 @@
|
|
|
+import os
|
|
|
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
|
|
+import torch
|
|
|
+import numpy as np
|
|
|
+import os.path as op
|
|
|
+import torch.nn.functional as F
|
|
|
+from datasets import build_dataloader
|
|
|
+from utils.checkpoint import Checkpointer
|
|
|
+from model import build_model
|
|
|
+from utils.metrics import Evaluator
|
|
|
+from utils.iotools import load_train_configs
|
|
|
+import random
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+from PIL import Image
|
|
|
+from datasets.cuhkpedes import CUHKPEDES
|
|
|
+
|
|
|
+
|
|
|
+config_file = '/xxx/configs.yaml'
|
|
|
+args = load_train_configs(config_file)
|
|
|
+args.batch_size = 1024
|
|
|
+args.training = False
|
|
|
+device = "cuda"
|
|
|
+test_img_loader, test_txt_loader = build_dataloader(args)
|
|
|
+model = build_model(args)
|
|
|
+checkpointer = Checkpointer(model)
|
|
|
+checkpointer.load(f=op.join(args.output_dir, 'best.pth'))
|
|
|
+model.to(device)
|
|
|
+
|
|
|
+evaluator = Evaluator(test_img_loader, test_txt_loader)
|
|
|
+
|
|
|
+qfeats, gfeats, qids, gids = evaluator._compute_embedding(model.eval())
|
|
|
+qfeats = F.normalize(qfeats, p=2, dim=1) # text features
|
|
|
+gfeats = F.normalize(gfeats, p=2, dim=1) # image features
|
|
|
+
|
|
|
+similarity = qfeats @ gfeats.t()
|
|
|
+# acclerate sort with topk
|
|
|
+_, indices = torch.topk(similarity, k=10, dim=1, largest=True, sorted=True) # q * topk
|
|
|
+
|
|
|
+dataset = CUHKPEDES(root='./data')
|
|
|
+test_dataset = dataset.test
|
|
|
+
|
|
|
+img_paths = test_dataset['img_paths']
|
|
|
+captions = test_dataset['captions']
|
|
|
+gt_img_paths = test_dataset['gt_img_paths']
|
|
|
+
|
|
|
+def get_one_query_caption_and_result_by_id(idx, indices, qids, gids, captions, img_paths, gt_img_paths):
|
|
|
+ query_caption = captions[idx]
|
|
|
+ query_id = qids[idx]
|
|
|
+ image_paths = [img_paths[j] for j in indices[idx]]
|
|
|
+ image_ids = gids[indices[idx]]
|
|
|
+ gt_image_path = gt_img_paths[idx]
|
|
|
+ return query_id, image_ids, query_caption, image_paths, gt_image_path
|
|
|
+
|
|
|
+def plot_retrieval_images(query_id, image_ids, query_caption, image_paths, gt_img_path, fname=None):
|
|
|
+ print(query_id)
|
|
|
+ print(image_ids)
|
|
|
+ print(query_caption)
|
|
|
+ fig = plt.figure()
|
|
|
+ col = len(image_paths)
|
|
|
+
|
|
|
+ # plot ground truth image
|
|
|
+ plt.subplot(1, col+1, 1)
|
|
|
+ img = Image.open(gt_img_path)
|
|
|
+ img = img.resize((128, 256))
|
|
|
+ plt.imshow(img)
|
|
|
+ plt.xticks([])
|
|
|
+ plt.yticks([])
|
|
|
+
|
|
|
+ for i in range(col):
|
|
|
+ plt.subplot(1, col+1, i+2)
|
|
|
+ img = Image.open(image_paths[i])
|
|
|
+
|
|
|
+ bwith = 2 # 边框宽度设置为2
|
|
|
+ ax = plt.gca() # 获取边框
|
|
|
+ if image_ids[i] == query_id:
|
|
|
+ ax.spines['top'].set_color('lawngreen')
|
|
|
+ ax.spines['right'].set_color('lawngreen')
|
|
|
+ ax.spines['bottom'].set_color('lawngreen')
|
|
|
+ ax.spines['left'].set_color('lawngreen')
|
|
|
+ else:
|
|
|
+ ax.spines['top'].set_color('red')
|
|
|
+ ax.spines['right'].set_color('red')
|
|
|
+ ax.spines['bottom'].set_color('red')
|
|
|
+ ax.spines['left'].set_color('red')
|
|
|
+
|
|
|
+ img = img.resize((128, 256))
|
|
|
+ plt.imshow(img)
|
|
|
+ plt.xticks([])
|
|
|
+ plt.yticks([])
|
|
|
+
|
|
|
+
|
|
|
+ fig.show()
|
|
|
+ if fname:
|
|
|
+ plt.savefig(fname, dpi=300)
|
|
|
+# idx is the index of qids(A list of query ids, range from 0 - len(qids))
|
|
|
+query_id, image_ids, query_caption, image_paths, gt_img_path = get_one_query_caption_and_result_by_id(0, indices, qids, gids, captions, img_paths, gt_img_paths)
|
|
|
+plot_retrieval_images(query_id, image_ids, query_caption, image_paths, gt_img_path)
|