Browse Source

refactor(eval): 重构 evaluation 过程【转移至服务器运行】

- 新增 dist_collect 函数,用于收集所有 GPU 上的 tensor
- 修改 Evaluator 类中 _compute_embedding 方法,使用新的 dist_collect 函数
- 在 train_one_epoch 函数中添加打印语句,用于调试
- 在 validate_reid 函数中添加 dist.barrier(),同步所有进程
Yijun Fu 1 month ago
parent
commit
e9441ab3ef
2 changed files with 38 additions and 8 deletions
  1. 8 1
      main_group_vit.py
  2. 30 7
      utils/metrics.py

+ 8 - 1
main_group_vit.py

@@ -266,6 +266,7 @@ def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler):
     start = time.time()
     end = time.time()
     for idx, samples in enumerate(data_loader):
+        print('\n\n1\n\n')
 
         batch_size = config.data.batch_size
 
@@ -311,6 +312,7 @@ def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler):
             lr_scheduler.step_update(epoch * num_steps + idx)
 
         torch.cuda.synchronize()
+        print('\n\n2\n\n')
 
         loss_meter.update(loss.item(), batch_size)
         for loss_name in log_vars:
@@ -318,8 +320,10 @@ def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler):
         norm_meter.update(grad_norm)
         batch_time.update(time.time() - end)
         end = time.time()
+        print('\n\n3\n\n')
 
         if idx % config.print_freq == 0:
+            print('\n\n4\n\n')
             lr = optimizer.param_groups[0]['lr']
             memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
             etas = batch_time.avg * (num_steps - idx)
@@ -331,12 +335,14 @@ def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler):
                         f'{log_vars_str}\t'
                         f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                         f'mem {memory_used:.0f}MB')
+            print('\n\n5\n\n')
             if wandb is not None:
                 log_stat = {f'iter/train_{n}': m.avg for n, m in log_vars_meters.items()}
                 log_stat['iter/train_total_loss'] = loss_meter.avg
                 log_stat['iter/learning_rate'] = lr
                 wandb.log(log_stat)
 
+    print('\n\n6\n\n')
     epoch_time = time.time() - start
     logger.info(f'EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}')
     result_dict = dict(total_loss=loss_meter.avg)
@@ -477,6 +483,7 @@ def validate_reid(cfg, img_loader, txt_loader, model):
     # rank1_result = metric[0]['Rank-1'] * 100
 
     torch.cuda.empty_cache()
+    dist.barrier()
 
     return rank1
 
@@ -542,4 +549,4 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
+    main()

+ 30 - 7
utils/metrics.py

@@ -4,6 +4,21 @@ import numpy as np
 import os
 import torch.nn.functional as F
 import logging
+import diffdist.functional as diff_dist
+import torch.distributed as dist
+
+
+def dist_collect(x):
+    """ collect all tensor from all GPUs
+    args:
+        x: shape (mini_batch, ...)
+    returns:
+        shape (mini_batch * num_gpu, ...)
+    """
+    x = x.contiguous()
+    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())]
+    out_list = diff_dist.all_gather(out_list, x)
+    return torch.cat(out_list, dim=0).contiguous()
 
 
 def rank(similarity, q_pids, g_pids, max_rank=10, get_mAP=True):
@@ -43,7 +58,7 @@ class Evaluator():
     def __init__(self, img_loader, txt_loader):
         self.img_loader = img_loader # gallery
         self.txt_loader = txt_loader # query
-        self.logger = logging.getLogger("IRRA.eval")
+        self.logger = logging.getLogger("GroupViT_irra.eval")
 
     def _compute_embedding(self, model):
         model = model.eval()
@@ -52,13 +67,17 @@ class Evaluator():
         qids, gids, qfeats, gfeats = [], [], [], []
         # text
         for pid, caption in self.txt_loader:
-            print('pid', pid.shape[0])
-            print('caption: ', caption.shape[0])
+            # print('pid', pid.shape[0])
+            # print('caption: ', caption.shape[0])
             caption = caption.to(device)
             with torch.no_grad():
-                text_feat = model.encode_text(caption)
+                text_outs = model.encode_text(caption, as_dict=True)
+                # [B, C]
+                text_x = text_outs['text_x']
+                text_x = F.normalize(text_x, dim=-1)
+                # text_feat = text_x @ dist_collect(image_x).t()
             qids.append(pid.view(-1)) # flatten 
-            qfeats.append(text_feat)
+            qfeats.append(text_x)
         qids = torch.cat(qids, 0)
         qfeats = torch.cat(qfeats, 0)
 
@@ -66,9 +85,13 @@ class Evaluator():
         for pid, img in self.img_loader:
             img = img.to(device)
             with torch.no_grad():
-                img_feat = model.encode_image(img)
+                image_outs = model.encode_image(img, as_dict=True)
+                # [B, C]
+                image_x = image_outs['image_x']
+                image_x = F.normalize(image_x, dim=-1)
+                # img_feat = image_x @ dist_collect(text_x).t()
             gids.append(pid.view(-1)) # flatten 
-            gfeats.append(img_feat)
+            gfeats.append(image_x)
         gids = torch.cat(gids, 0)
         gfeats = torch.cat(gfeats, 0)