metrics.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. from prettytable import PrettyTable
  2. import torch
  3. import numpy as np
  4. import os
  5. import torch.nn.functional as F
  6. import logging
  7. import diffdist.functional as diff_dist
  8. import torch.distributed as dist
  9. def dist_collect(x):
  10. """ collect all tensor from all GPUs
  11. args:
  12. x: shape (mini_batch, ...)
  13. returns:
  14. shape (mini_batch * num_gpu, ...)
  15. """
  16. x = x.contiguous()
  17. out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())]
  18. out_list = diff_dist.all_gather(out_list, x)
  19. return torch.cat(out_list, dim=0).contiguous()
  20. def rank(similarity, q_pids, g_pids, max_rank=10, get_mAP=True):
  21. if get_mAP:
  22. indices = torch.argsort(similarity, dim=1, descending=True)
  23. else:
  24. # acclerate sort with topk
  25. _, indices = torch.topk(
  26. similarity, k=max_rank, dim=1, largest=True, sorted=True
  27. ) # q * topk
  28. pred_labels = g_pids[indices.cpu()] # q * k
  29. matches = pred_labels.eq(q_pids.view(-1, 1)) # q * k
  30. all_cmc = matches[:, :max_rank].cumsum(1) # cumulative sum
  31. all_cmc[all_cmc > 1] = 1
  32. all_cmc = all_cmc.float().mean(0) * 100
  33. # all_cmc = all_cmc[topk - 1]
  34. if not get_mAP:
  35. return all_cmc, indices
  36. num_rel = matches.sum(1) # q
  37. tmp_cmc = matches.cumsum(1) # q * k
  38. inp = [tmp_cmc[i][match_row.nonzero()[-1]] / (match_row.nonzero()[-1] + 1.) for i, match_row in enumerate(matches)]
  39. mINP = torch.cat(inp).mean() * 100
  40. tmp_cmc = [tmp_cmc[:, i] / (i + 1.0) for i in range(tmp_cmc.shape[1])]
  41. tmp_cmc = torch.stack(tmp_cmc, 1) * matches
  42. AP = tmp_cmc.sum(1) / num_rel # q
  43. mAP = AP.mean() * 100
  44. return all_cmc, mAP, mINP, indices
  45. class Evaluator():
  46. def __init__(self, img_loader, txt_loader):
  47. self.img_loader = img_loader # gallery
  48. self.txt_loader = txt_loader # query
  49. self.logger = logging.getLogger("GroupViT_irra.eval")
  50. def _compute_embedding(self, model):
  51. model = model.eval()
  52. device = next(model.parameters()).device
  53. qids, gids, qfeats, gfeats = [], [], [], []
  54. # text
  55. for pid, caption in self.txt_loader:
  56. # print('pid', pid.shape[0])
  57. # print('caption: ', caption.shape[0])
  58. caption = caption.to(device)
  59. with torch.no_grad():
  60. text_outs = model.encode_text(caption, as_dict=True)
  61. # [B, C]
  62. text_x = text_outs['text_x']
  63. text_x = F.normalize(text_x, dim=-1)
  64. # text_feat = text_x @ dist_collect(image_x).t()
  65. qids.append(pid.view(-1)) # flatten
  66. qfeats.append(text_x)
  67. qids = torch.cat(qids, 0)
  68. qfeats = torch.cat(qfeats, 0)
  69. # image
  70. for pid, img in self.img_loader:
  71. img = img.to(device)
  72. with torch.no_grad():
  73. image_outs = model.encode_image(img, as_dict=True)
  74. # [B, C]
  75. image_x = image_outs['image_x']
  76. image_x = F.normalize(image_x, dim=-1)
  77. # img_feat = image_x @ dist_collect(text_x).t()
  78. gids.append(pid.view(-1)) # flatten
  79. gfeats.append(image_x)
  80. gids = torch.cat(gids, 0)
  81. gfeats = torch.cat(gfeats, 0)
  82. return qfeats, gfeats, qids, gids
  83. def eval(self, model, i2t_metric=False):
  84. qfeats, gfeats, qids, gids = self._compute_embedding(model)
  85. qfeats = F.normalize(qfeats, p=2, dim=1) # text features
  86. gfeats = F.normalize(gfeats, p=2, dim=1) # image features
  87. similarity = qfeats @ gfeats.t()
  88. t2i_cmc, t2i_mAP, t2i_mINP, _ = rank(similarity=similarity, q_pids=qids, g_pids=gids, max_rank=10, get_mAP=True)
  89. t2i_cmc, t2i_mAP, t2i_mINP = t2i_cmc.numpy(), t2i_mAP.numpy(), t2i_mINP.numpy()
  90. table = PrettyTable(["task", "R1", "R5", "R10", "mAP", "mINP"])
  91. table.add_row(['t2i', t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_mAP, t2i_mINP])
  92. if i2t_metric:
  93. i2t_cmc, i2t_mAP, i2t_mINP, _ = rank(similarity=similarity.t(), q_pids=gids, g_pids=qids, max_rank=10, get_mAP=True)
  94. i2t_cmc, i2t_mAP, i2t_mINP = i2t_cmc.numpy(), i2t_mAP.numpy(), i2t_mINP.numpy()
  95. table.add_row(['i2t', i2t_cmc[0], i2t_cmc[4], i2t_cmc[9], i2t_mAP, i2t_mINP])
  96. # table.float_format = '.4'
  97. table.custom_format["R1"] = lambda f, v: f"{v:.3f}"
  98. table.custom_format["R5"] = lambda f, v: f"{v:.3f}"
  99. table.custom_format["R10"] = lambda f, v: f"{v:.3f}"
  100. table.custom_format["mAP"] = lambda f, v: f"{v:.3f}"
  101. table.custom_format["mINP"] = lambda f, v: f"{v:.3f}"
  102. self.logger.info('\n' + str(table))
  103. return t2i_cmc[0]