
The code runs, but the results are not right

refactor(reid): refactor the person re-identification validation flow

- Remove the use of MMDistributedDataParallel and multi_gpu_test
- Introduce an Evaluator class to perform ReID validation
- Simplify the code structure for better readability and maintainability
Yijun Fu, 1 month ago
commit 90ddc4c5f7
2 files changed, 127 additions and 22 deletions
  1. +24 -22  main_group_vit.py
  2. +103 -0  utils/metrics.py
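In outline, the refactored validate_reid drops the mmseg-style multi_gpu_test path and instead builds an Evaluator from the gallery (image) and query (text) loaders, then calls eval() on the unwrapped model; the return value is the text-to-image Rank-1 score in percent. A minimal usage sketch (names follow the diff below):

    from utils.metrics import Evaluator

    evaluator = Evaluator(img_loader, txt_loader)       # gallery images, text queries
    rank1 = evaluator.eval(model_without_ddp.eval())    # t2i Rank-1, in percent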

+24 -22
main_group_vit.py

@@ -51,6 +51,7 @@ from utils import (auto_resume_helper, build_dataset_class_tokens, build_optimiz
                    get_config, get_grad_norm, get_logger, load_checkpoint, parse_losses, reduce_tensor, save_checkpoint)
 
 from tools.cfg2arg import cfg2arg
+from utils.metrics import Evaluator
 
 
 try:
@@ -446,37 +447,38 @@ def validate_seg(config, data_loader, model):
 def validate_reid(cfg, img_loader, txt_loader, model):
     logger = get_logger()
     dist.barrier()
-    model.eval()
+    # model.eval()
+    evaluator = Evaluator(img_loader, txt_loader)
 
     if hasattr(model, 'module'):
         model_without_ddp = model.module
     else:
         model_without_ddp = model
 
-    reid_model = build_reid_inference(model_without_ddp, img_loader, txt_loader, cfg.evaluate.reid)
-
-    mmddp_model = MMDistributedDataParallel(
-        reid_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
-    mmddp_model.eval()
-    results = multi_gpu_test(
-        model=mmddp_model,
-        data_loader=img_loader,
-        tmpdir=None,
-        gpu_collect=True,
-        efficient_test=False,
-        pre_eval=True,
-        format_only=False)
-
-    if dist.get_rank() == 0:
-        metric = [img_loader.dataset.evaluate(results, metric='Rank-1')]
-    else:
-        metric = [None]
-    dist.broadcast_object_list(metric)
-    rank1_result = metric[0]['Rank-1'] * 100
+    # reid_model = build_reid_inference(model_without_ddp, img_loader, txt_loader, cfg.evaluate.reid)
+
+    # mmddp_model = MMDistributedDataParallel(
+    #     reid_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
+    rank1 = evaluator.eval(model_without_ddp.eval())
+    # results = multi_gpu_test(
+    #     model=mmddp_model,
+    #     data_loader=img_loader,
+    #     tmpdir=None,
+    #     gpu_collect=True,
+    #     efficient_test=False,
+    #     pre_eval=True,
+    #     format_only=False)
+
+    # if dist.get_rank() == 0:
+    #     metric = [img_loader.dataset.evaluate(results, metric='Rank-1')]
+    # else:
+    #     metric = [None]
+    # dist.broadcast_object_list(metric)
+    # rank1_result = metric[0]['Rank-1'] * 100
 
     torch.cuda.empty_cache()
 
-    return rank1_result
+    return rank1
 
 
 def main():

+103 -0
utils/metrics.py

@@ -0,0 +1,103 @@
+from prettytable import PrettyTable
+import torch
+import numpy as np
+import os
+import torch.nn.functional as F
+import logging
+
+
+def rank(similarity, q_pids, g_pids, max_rank=10, get_mAP=True):
+    if get_mAP:
+        indices = torch.argsort(similarity, dim=1, descending=True)
+    else:
+        # accelerate sort with topk
+        _, indices = torch.topk(
+            similarity, k=max_rank, dim=1, largest=True, sorted=True
+        )  # q * topk
+    pred_labels = g_pids[indices.cpu()]  # q * k
+    matches = pred_labels.eq(q_pids.view(-1, 1))  # q * k
+
+    all_cmc = matches[:, :max_rank].cumsum(1) # cumulative sum
+    all_cmc[all_cmc > 1] = 1
+    all_cmc = all_cmc.float().mean(0) * 100
+    # all_cmc = all_cmc[topk - 1]
+
+    if not get_mAP:
+        return all_cmc, indices
+
+    num_rel = matches.sum(1)  # q
+    tmp_cmc = matches.cumsum(1)  # q * k
+
+    inp = [tmp_cmc[i][match_row.nonzero()[-1]] / (match_row.nonzero()[-1] + 1.) for i, match_row in enumerate(matches)]
+    mINP = torch.cat(inp).mean() * 100
+
+    tmp_cmc = [tmp_cmc[:, i] / (i + 1.0) for i in range(tmp_cmc.shape[1])]
+    tmp_cmc = torch.stack(tmp_cmc, 1) * matches
+    AP = tmp_cmc.sum(1) / num_rel  # q
+    mAP = AP.mean() * 100
+
+    return all_cmc, mAP, mINP, indices
+
+
+class Evaluator():
+    def __init__(self, img_loader, txt_loader):
+        self.img_loader = img_loader # gallery
+        self.txt_loader = txt_loader # query
+        self.logger = logging.getLogger("IRRA.eval")
+
+    def _compute_embedding(self, model):
+        model = model.eval()
+        device = next(model.parameters()).device
+
+        qids, gids, qfeats, gfeats = [], [], [], []
+        # text
+        for pid, caption in self.txt_loader:
+            print('pid', pid.shape[0])
+            print('caption: ', caption.shape[0])
+            caption = caption.to(device)
+            with torch.no_grad():
+                text_feat = model.encode_text(caption)
+            qids.append(pid.view(-1)) # flatten 
+            qfeats.append(text_feat)
+        qids = torch.cat(qids, 0)
+        qfeats = torch.cat(qfeats, 0)
+
+        # image
+        for pid, img in self.img_loader:
+            img = img.to(device)
+            with torch.no_grad():
+                img_feat = model.encode_image(img)
+            gids.append(pid.view(-1)) # flatten 
+            gfeats.append(img_feat)
+        gids = torch.cat(gids, 0)
+        gfeats = torch.cat(gfeats, 0)
+
+        return qfeats, gfeats, qids, gids
+    
+    def eval(self, model, i2t_metric=False):
+
+        qfeats, gfeats, qids, gids = self._compute_embedding(model)
+
+        qfeats = F.normalize(qfeats, p=2, dim=1) # text features
+        gfeats = F.normalize(gfeats, p=2, dim=1) # image features
+
+        similarity = qfeats @ gfeats.t()
+
+        t2i_cmc, t2i_mAP, t2i_mINP, _ = rank(similarity=similarity, q_pids=qids, g_pids=gids, max_rank=10, get_mAP=True)
+        t2i_cmc, t2i_mAP, t2i_mINP = t2i_cmc.numpy(), t2i_mAP.numpy(), t2i_mINP.numpy()
+        table = PrettyTable(["task", "R1", "R5", "R10", "mAP", "mINP"])
+        table.add_row(['t2i', t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_mAP, t2i_mINP])
+
+        if i2t_metric:
+            i2t_cmc, i2t_mAP, i2t_mINP, _ = rank(similarity=similarity.t(), q_pids=gids, g_pids=qids, max_rank=10, get_mAP=True)
+            i2t_cmc, i2t_mAP, i2t_mINP = i2t_cmc.numpy(), i2t_mAP.numpy(), i2t_mINP.numpy()
+            table.add_row(['i2t', i2t_cmc[0], i2t_cmc[4], i2t_cmc[9], i2t_mAP, i2t_mINP])
+        # table.float_format = '.4'
+        table.custom_format["R1"] = lambda f, v: f"{v:.3f}"
+        table.custom_format["R5"] = lambda f, v: f"{v:.3f}"
+        table.custom_format["R10"] = lambda f, v: f"{v:.3f}"
+        table.custom_format["mAP"] = lambda f, v: f"{v:.3f}"
+        table.custom_format["mINP"] = lambda f, v: f"{v:.3f}"
+        self.logger.info('\n' + str(table))
+        
+        return t2i_cmc[0]
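As a reference point, here is a self-contained sanity check of the rank() helper on hypothetical toy tensors (values invented for illustration, not part of this commit). Every query's top-ranked gallery image carries the correct identity, so cmc[0] (Rank-1) should come out at 100.0:

    import torch
    from utils.metrics import rank

    # 3 text queries vs. 4 gallery images; row i holds query i's similarity
    # to every gallery image (hypothetical scores).
    similarity = torch.tensor([
        [0.9, 0.1, 0.2, 0.3],   # query 0 ranks gallery image 0 first
        [0.2, 0.8, 0.1, 0.4],   # query 1 ranks gallery image 1 first
        [0.3, 0.2, 0.1, 0.7],   # query 2 ranks gallery image 3 first
    ])
    q_pids = torch.tensor([0, 1, 2])
    g_pids = torch.tensor([0, 1, 2, 2])   # gallery images 2 and 3 share identity 2

    cmc, mAP, mINP, _ = rank(similarity, q_pids, g_pids, max_rank=4, get_mAP=True)
    print(cmc[0].item())              # Rank-1 in percent -> 100.0
    print(mAP.item(), mINP.item())    # mAP and mINP, also in percent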