@@ -4,6 +4,21 @@ import numpy as np
 import os
 import torch.nn.functional as F
 import logging
+import diffdist.functional as diff_dist
+import torch.distributed as dist
+
+
+def dist_collect(x):
+    """ collect all tensors from all GPUs
+    args:
+        x: shape (mini_batch, ...)
+    returns:
+        shape (mini_batch * num_gpu, ...)
+    """
+    x = x.contiguous()
+    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())]
+    out_list = diff_dist.all_gather(out_list, x)
+    return torch.cat(out_list, dim=0).contiguous()
 
 
 def rank(similarity, q_pids, g_pids, max_rank=10, get_mAP=True):
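Note (commentary, not part of the patch): dist_collect is a differentiable all-gather built on diffdist, so gradients flow back to every rank, unlike a plain torch.distributed.all_gather. A minimal usage sketch, assuming a process group is already initialized and image_x / text_x are L2-normalized [B, C] features on the local rank; the helper name contrastive_logits and the temperature value are illustrative, not from this patch:

    # Illustrative only -- assumes dist.init_process_group() has run
    # and dist_collect is the helper added by this patch.
    def contrastive_logits(image_x, text_x, temperature=0.07):
        # Gather text features from every rank; diffdist keeps the op
        # differentiable, so the loss sees the full global batch.
        all_text_x = dist_collect(text_x)              # [B * world_size, C]
        # Cosine similarity of local images vs. the global text batch.
        return image_x @ all_text_x.t() / temperature  # [B, B * world_size]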
@@ -43,7 +58,7 @@ class Evaluator():
     def __init__(self, img_loader, txt_loader):
         self.img_loader = img_loader # gallery
         self.txt_loader = txt_loader # query
-        self.logger = logging.getLogger("IRRA.eval")
+        self.logger = logging.getLogger("GroupViT_irra.eval")
 
     def _compute_embedding(self, model):
         model = model.eval()
@@ -52,13 +67,17 @@ class Evaluator():
         qids, gids, qfeats, gfeats = [], [], [], []
         # text
         for pid, caption in self.txt_loader:
-            print('pid', pid.shape[0])
-            print('caption: ', caption.shape[0])
+            # print('pid', pid.shape[0])
+            # print('caption: ', caption.shape[0])
             caption = caption.to(device)
             with torch.no_grad():
-                text_feat = model.encode_text(caption)
+                text_outs = model.encode_text(caption, as_dict=True)
+                # [B, C]
+                text_x = text_outs['text_x']
+                text_x = F.normalize(text_x, dim=-1)
+                # text_feat = text_x @ dist_collect(image_x).t()
             qids.append(pid.view(-1)) # flatten
-            qfeats.append(text_feat)
+            qfeats.append(text_x)
         qids = torch.cat(qids, 0)
         qfeats = torch.cat(qfeats, 0)
 
@@ -66,9 +85,13 @@ class Evaluator():
         for pid, img in self.img_loader:
             img = img.to(device)
             with torch.no_grad():
-                img_feat = model.encode_image(img)
+                image_outs = model.encode_image(img, as_dict=True)
+                # [B, C]
+                image_x = image_outs['image_x']
+                image_x = F.normalize(image_x, dim=-1)
+                # img_feat = image_x @ dist_collect(text_x).t()
             gids.append(pid.view(-1)) # flatten
-            gfeats.append(img_feat)
+            gfeats.append(image_x)
         gids = torch.cat(gids, 0)
         gfeats = torch.cat(gfeats, 0)
 
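Note (commentary, not part of the patch): since qfeats and gfeats are now L2-normalized, the similarity matrix that rank() consumes reduces to a single matrix product. A minimal sketch, assuming the caller follows the existing IRRA eval flow; the exact return values of rank() are left commented since they depend on the surrounding code:

    # Illustrative only: cosine similarity between every query (text)
    # embedding and every gallery (image) embedding.
    similarity = qfeats @ gfeats.t()   # [num_query, num_gallery]
    # results = rank(similarity, qids, gids, max_rank=10, get_mAP=True)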