# visualize.py — visualize top-10 text-to-image retrieval results on CUHK-PEDES
  1. import os
  2. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  3. import torch
  4. import numpy as np
  5. import os.path as op
  6. import torch.nn.functional as F
  7. from datasets import build_dataloader
  8. from utils.checkpoint import Checkpointer
  9. from model import build_model
  10. from utils.metrics import Evaluator
  11. from utils.iotools import load_train_configs
  12. import random
  13. import matplotlib.pyplot as plt
  14. from PIL import Image
  15. from datasets.cuhkpedes import CUHKPEDES
  16. config_file = '/xxx/configs.yaml'
  17. args = load_train_configs(config_file)
  18. args.batch_size = 1024
  19. args.training = False
  20. device = "cuda"
  21. test_img_loader, test_txt_loader = build_dataloader(args)
  22. model = build_model(args)
  23. checkpointer = Checkpointer(model)
  24. checkpointer.load(f=op.join(args.output_dir, 'best.pth'))
  25. model.to(device)
  26. evaluator = Evaluator(test_img_loader, test_txt_loader)
  27. qfeats, gfeats, qids, gids = evaluator._compute_embedding(model.eval())
  28. qfeats = F.normalize(qfeats, p=2, dim=1) # text features
  29. gfeats = F.normalize(gfeats, p=2, dim=1) # image features
  30. similarity = qfeats @ gfeats.t()
  31. # acclerate sort with topk
  32. _, indices = torch.topk(similarity, k=10, dim=1, largest=True, sorted=True) # q * topk
  33. dataset = CUHKPEDES(root='./data')
  34. test_dataset = dataset.test
  35. img_paths = test_dataset['img_paths']
  36. captions = test_dataset['captions']
  37. gt_img_paths = test_dataset['gt_img_paths']
  38. def get_one_query_caption_and_result_by_id(idx, indices, qids, gids, captions, img_paths, gt_img_paths):
  39. query_caption = captions[idx]
  40. query_id = qids[idx]
  41. image_paths = [img_paths[j] for j in indices[idx]]
  42. image_ids = gids[indices[idx]]
  43. gt_image_path = gt_img_paths[idx]
  44. return query_id, image_ids, query_caption, image_paths, gt_image_path
  45. def plot_retrieval_images(query_id, image_ids, query_caption, image_paths, gt_img_path, fname=None):
  46. print(query_id)
  47. print(image_ids)
  48. print(query_caption)
  49. fig = plt.figure()
  50. col = len(image_paths)
  51. # plot ground truth image
  52. plt.subplot(1, col+1, 1)
  53. img = Image.open(gt_img_path)
  54. img = img.resize((128, 256))
  55. plt.imshow(img)
  56. plt.xticks([])
  57. plt.yticks([])
  58. for i in range(col):
  59. plt.subplot(1, col+1, i+2)
  60. img = Image.open(image_paths[i])
  61. bwith = 2 # 边框宽度设置为2
  62. ax = plt.gca() # 获取边框
  63. if image_ids[i] == query_id:
  64. ax.spines['top'].set_color('lawngreen')
  65. ax.spines['right'].set_color('lawngreen')
  66. ax.spines['bottom'].set_color('lawngreen')
  67. ax.spines['left'].set_color('lawngreen')
  68. else:
  69. ax.spines['top'].set_color('red')
  70. ax.spines['right'].set_color('red')
  71. ax.spines['bottom'].set_color('red')
  72. ax.spines['left'].set_color('red')
  73. img = img.resize((128, 256))
  74. plt.imshow(img)
  75. plt.xticks([])
  76. plt.yticks([])
  77. fig.show()
  78. if fname:
  79. plt.savefig(fname, dpi=300)
  80. # idx is the index of qids(A list of query ids, range from 0 - len(qids))
  81. query_id, image_ids, query_caption, image_paths, gt_img_path = get_one_query_caption_and_result_by_id(0, indices, qids, gids, captions, img_paths, gt_img_paths)
  82. plot_retrieval_images(query_id, image_ids, query_caption, image_paths, gt_img_path)