Pārlūkot izejas kodu

feat(evaluate): 更新 reid 评估方法

- 修改默认配置文件,启用 reid 任务并注释掉 cls 和 seg 任务
- 重构 validate_reid 函数,支持多 GPU 评估
- 更新 train 函数中的 reid 评估逻辑
- 优化日志输出,显示 Rank-1 结果
Yijun Fu 1 mēnesi atpakaļ
vecāks
revīzija
8805f2ac02
2 mainītis faili ar 44 papildinājumiem un 10 dzēšanām
  1. 4 4
      configs/default.yml
  2. 40 6
      main_group_vit.py

+ 4 - 4
configs/default.yml

@@ -96,9 +96,9 @@ evaluate:
   eval_only: false
   eval_freq: 1
   task:
-    - cls
-    - seg
-    - retrieval
+    # - cls
+    # - seg
+    - reid
   cls:
     save_best: true
     template: subset
@@ -107,7 +107,7 @@ evaluate:
     cfg: segmentation/configs/_base_/datasets/pascal_voc12.py
     template: simple
     opts: []
-  retrieval:
+  reid:
     save_best: true
     template: simple
     opts: []

+ 40 - 6
main_group_vit.py

@@ -173,8 +173,10 @@ def train(cfg):
             miou = validate_seg(cfg, data_loader_seg, model)
             logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
         if 'reid' in cfg.evaluate.task:
-            mrank1 = validate_reid(cfg, data_loader_reid, model)
-            logger.info(f'Rank1 of the network on the {len(data_loader_reid)} test images: {mrank1:.2f}%')
+            # mrank1 = validate_reid(cfg, data_loader_reid, model)
+            mrank1 = validate_reid(cfg, val_img_loader, val_txt_loader, model)
+            # logger.info(f'Rank1 of the network on the {len(data_loader_reid)} test images: {mrank1:.2f}%')
+            logger.info(f'Rank1 of the network on the {len(val_img_loader)} test images: {mrank1:.2f}%')
         if cfg.evaluate.eval_only:
             return
 
@@ -215,8 +217,8 @@ def train(cfg):
                 max_miou = max_metrics['max_miou']
                 logger.info(f'Max mIoU: {max_miou:.2f}%')
             if 'reid' in cfg.evaluate.task:
-                mrank1 = validate_reid(cfg, data_loader_reid, model)
-                logger.info(f'mRank1 of the network on the {len(data_loader_reid)} test images: {mrank1:.2f}%')
+                mrank1 = validate_reid(cfg, val_img_loader, val_txt_loader, model)
+                logger.info(f'mRank1 of the network on the {len(val_img_loader)} test images: {mrank1:.2f}%')
                 max_metrics['max_rank1'] = max(max_metrics['max_rank1'], mrank1)
                 if cfg.evaluate.reid.save_best and dist.get_rank() == 0 and mrank1 > max_rank1:
                     save_checkpoint(
@@ -441,8 +443,40 @@ def validate_seg(config, data_loader, model):
 
 
 @torch.no_grad()
-def validate_reid(config, data_loader, model):
-    print()
+def validate_reid(cfg, img_loader, txt_loader, model):
+    logger = get_logger()
+    dist.barrier()
+    model.eval()
+
+    if hasattr(model, 'module'):
+        model_without_ddp = model.module
+    else:
+        model_without_ddp = model
+
+    reid_model = build_reid_inference(model_without_ddp, img_loader, txt_loader, cfg.evaluate.reid)
+
+    mmddp_model = MMDistributedDataParallel(
+        reid_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
+    mmddp_model.eval()
+    results = multi_gpu_test(
+        model=mmddp_model,
+        data_loader=img_loader,
+        tmpdir=None,
+        gpu_collect=True,
+        efficient_test=False,
+        pre_eval=True,
+        format_only=False)
+
+    if dist.get_rank() == 0:
+        metric = [img_loader.dataset.evaluate(results, metric='Rank-1')]
+    else:
+        metric = [None]
+    dist.broadcast_object_list(metric)
+    rank1_result = metric[0]['Rank-1'] * 100
+
+    torch.cuda.empty_cache()
+
+    return rank1_result
 
 
 def main():