12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- import os
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
- import torch
- import numpy as np
- import os.path as op
- import torch.nn.functional as F
- from datasets import build_dataloader
- from utils.checkpoint import Checkpointer
- from model import build_model
- from utils.metrics import Evaluator
- from utils.iotools import load_train_configs
- import random
- import matplotlib.pyplot as plt
- from PIL import Image
- from datasets.cuhkpedes import CUHKPEDES
- config_file = '/xxx/configs.yaml'
- args = load_train_configs(config_file)
- args.batch_size = 1024
- args.training = False
- device = "cuda"
- test_img_loader, test_txt_loader = build_dataloader(args)
- model = build_model(args)
- checkpointer = Checkpointer(model)
- checkpointer.load(f=op.join(args.output_dir, 'best.pth'))
- model.to(device)
- evaluator = Evaluator(test_img_loader, test_txt_loader)
- qfeats, gfeats, qids, gids = evaluator._compute_embedding(model.eval())
- qfeats = F.normalize(qfeats, p=2, dim=1) # text features
- gfeats = F.normalize(gfeats, p=2, dim=1) # image features
- similarity = qfeats @ gfeats.t()
- # acclerate sort with topk
- _, indices = torch.topk(similarity, k=10, dim=1, largest=True, sorted=True) # q * topk
- dataset = CUHKPEDES(root='./data')
- test_dataset = dataset.test
- img_paths = test_dataset['img_paths']
- captions = test_dataset['captions']
- gt_img_paths = test_dataset['gt_img_paths']
- def get_one_query_caption_and_result_by_id(idx, indices, qids, gids, captions, img_paths, gt_img_paths):
- query_caption = captions[idx]
- query_id = qids[idx]
- image_paths = [img_paths[j] for j in indices[idx]]
- image_ids = gids[indices[idx]]
- gt_image_path = gt_img_paths[idx]
- return query_id, image_ids, query_caption, image_paths, gt_image_path
- def plot_retrieval_images(query_id, image_ids, query_caption, image_paths, gt_img_path, fname=None):
- print(query_id)
- print(image_ids)
- print(query_caption)
- fig = plt.figure()
- col = len(image_paths)
- # plot ground truth image
- plt.subplot(1, col+1, 1)
- img = Image.open(gt_img_path)
- img = img.resize((128, 256))
- plt.imshow(img)
- plt.xticks([])
- plt.yticks([])
- for i in range(col):
- plt.subplot(1, col+1, i+2)
- img = Image.open(image_paths[i])
- bwith = 2 # 边框宽度设置为2
- ax = plt.gca() # 获取边框
- if image_ids[i] == query_id:
- ax.spines['top'].set_color('lawngreen')
- ax.spines['right'].set_color('lawngreen')
- ax.spines['bottom'].set_color('lawngreen')
- ax.spines['left'].set_color('lawngreen')
- else:
- ax.spines['top'].set_color('red')
- ax.spines['right'].set_color('red')
- ax.spines['bottom'].set_color('red')
- ax.spines['left'].set_color('red')
-
- img = img.resize((128, 256))
- plt.imshow(img)
- plt.xticks([])
- plt.yticks([])
-
- fig.show()
- if fname:
- plt.savefig(fname, dpi=300)
- # idx is the index of qids(A list of query ids, range from 0 - len(qids))
- query_id, image_ids, query_caption, image_paths, gt_img_path = get_one_query_caption_and_result_by_id(0, indices, qids, gids, captions, img_paths, gt_img_paths)
- plot_retrieval_images(query_id, image_ids, query_caption, image_paths, gt_img_path)
|