# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
# Modified by Jilan Xu
# -------------------------------------------------------------------------
import diffdist.functional as diff_dist
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from timm.loss import SoftTargetCrossEntropy
from random import choice

from .builder import MODELS
from .misc import Result
from .losses import HungarianMatcher, dice_loss

from ipdb import set_trace
# roi_pool is a function inside torchvision.ops, so import it directly instead of aliasing the submodule.
from torchvision.ops import roi_pool
import cv2
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from .group_vit import CrossAttnBlock, AssignAttention, AttnBlock


def dist_collect(x):
    """Collect a tensor from all GPUs in a differentiable way.

    Args:
        x: shape (mini_batch, ...)
    Returns:
        shape (mini_batch * num_gpu, ...)
    """
    x = x.contiguous()
    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())]
    out_list = diff_dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0).contiguous()
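

# Illustrative sketch (not part of the original model): diff_dist.all_gather keeps the locally
# computed shard on the autograd graph, unlike torch.distributed.all_gather, so the contrastive
# losses below still receive gradients for the local features. A hedged fallback, assuming the
# module might also be run without an initialised process group:
def dist_collect_or_identity(x):
    """Gather from all ranks when distributed is initialised; otherwise return x unchanged."""
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
        return x
    return dist_collect(x)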


class ProjectMLP(nn.Module):

    def __init__(self, in_dim=256, inner_dim=4096, out_dim=256, num_layers=2):
        super(ProjectMLP, self).__init__()
        # hidden layers
        linear_hidden = []
        for i in range(num_layers - 1):
            linear_hidden.append(nn.Conv1d(in_dim if i == 0 else inner_dim, inner_dim, kernel_size=1))
            linear_hidden.append(nn.BatchNorm1d(inner_dim))
            linear_hidden.append(nn.ReLU(inplace=True))
        self.linear_hidden = nn.Sequential(*linear_hidden)

        self.linear_out = nn.Conv1d(
            in_dim if num_layers == 1 else inner_dim, out_dim, kernel_size=1) if num_layers >= 1 else nn.Identity()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): output of transformers, shape [B, L, C]

        Returns:
            torch.Tensor: projected features, shape [B, L, out_dim] ([B, out_dim] if the input was 2-D)
        """
        assert x.ndim in [2, 3], x.ndim
        add_dim = False
        if x.ndim == 2:
            # [B, C] -> [B, 1, C]
            x = x.unsqueeze(1)
            add_dim = True

        x = rearrange(x, 'b l c -> b c l')
        x = self.linear_hidden(x)
        x = self.linear_out(x)
        x = rearrange(x, 'b c l -> b l c')

        if add_dim:
            x = x.squeeze(1)

        return x
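

# Usage sketch (illustrative only; the dimensions below are assumptions, not values from a config):
#   proj = ProjectMLP(in_dim=512, inner_dim=4096, out_dim=256, num_layers=2)
#   tokens = torch.randn(4, 8, 512)   # [B, L, C] group or word tokens
#   out = proj(tokens)                # [B, L, 256]; a [B, C] input would return [B, 256]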


class MultimodalGroupingBlock(nn.Module):
    """Multimodal grouping block that fuses visual group tokens with text tokens.

    Args:
        dim (int): Dimension of the input.
        out_dim (int): Dimension of the output.
        num_heads (int): Number of heads in the grouping attention.
        norm_layer (nn.Module): Normalization layer to use.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: (0.5, 4.0).
        hard (bool): Whether to use hard or soft assignment. Default: True.
        gumbel (bool): Whether to use gumbel softmax. Default: True.
        sum_assign (bool): Whether to sum assignments or average them. Default: False.
        assign_eps (float): Epsilon to avoid division by zero. Default: 1.
        gumbel_tau (float): Temperature for gumbel softmax. Default: 1.
        attn_drop (float): Attention dropout rate. Default: 0.
    """

    def __init__(self,
                 *,
                 dim,
                 out_dim,
                 num_heads,
                 norm_layer,
                 mlp_ratio=(0.5, 4.0),
                 hard=True,
                 gumbel=True,
                 sum_assign=False,
                 assign_eps=1.,
                 gumbel_tau=1.,
                 attn_drop=0.,
                 ):
        super(MultimodalGroupingBlock, self).__init__()
        self.dim = dim
        self.hard = hard
        self.gumbel = gumbel
        self.sum_assign = sum_assign
        # norm on group_tokens
        self.norm_tokens = norm_layer(dim)
        tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
        # norm on x
        self.norm_x = norm_layer(dim)
        # self.visual_attn = AttnBlock(
        #     dim=dim, num_heads=num_heads, mlp_ratio=4, qkv_bias=True, norm_layer=norm_layer)
        self.pre_assign_attn = CrossAttnBlock(
            dim=dim, num_heads=num_heads, mlp_ratio=4, qkv_bias=True, norm_layer=norm_layer, post_norm=True)
        self.post_attn = AttnBlock(
            dim=dim, num_heads=num_heads, mlp_ratio=4, qkv_bias=True, norm_layer=norm_layer)
        self.assign = AssignAttention(
            dim=dim,
            num_heads=1,
            qkv_bias=True,
            hard=hard,
            gumbel=gumbel,
            gumbel_tau=gumbel_tau,
            sum_assign=sum_assign,
            assign_eps=assign_eps,
            attn_drop=attn_drop,
        )
        self.norm_new_x = norm_layer(dim)

    def forward(self, ans_tokens, visual_tokens, text_tokens, entity_masks=None, question_masks=None, return_attn=False):
        """
        Args:
            ans_tokens (torch.Tensor): answer tokens from the previous block, [B, L, C], or None for the first block
            visual_tokens (torch.Tensor): visual group tokens, [B, K, C]
            text_tokens (torch.Tensor): word tokens, [B, L, C]
            entity_masks (torch.Tensor, optional): mask over entity tokens, [B, L]
            question_masks (torch.Tensor, optional): attention mask over question tokens, [B, L]
            return_attn (bool): whether to return the attention map

        Returns:
            new_x (torch.Tensor): [B, L, C] fused multimodal tokens
        """
        # [B, K, C], self-attention
        # visual_tokens = self.visual_attn(visual_tokens)
        text_tokens = self.norm_tokens(text_tokens)
        visual_tokens = self.norm_x(visual_tokens)
        # [B, L, C], cross attention
        projected_text_tokens = self.pre_assign_attn(text_tokens, visual_tokens)
        ### mask needs to be [b, 1, 77, 1] to match [b, nh, 77, k]
        # projected_text_tokens = text_tokens
        # new_x, attn_dict = self.assign(projected_text_tokens, visual_tokens, return_attn=return_attn, mask=question_masks)

        if ans_tokens is None:
            ans_temp = projected_text_tokens
        else:
            ans_temp = ans_tokens + projected_text_tokens

        ############## self-attn only ###################
        if question_masks is not None:
            new_x = self.post_attn(ans_temp, mask=question_masks)
        else:
            new_x = self.post_attn(ans_temp)
        new_x = new_x + projected_text_tokens

        new_x = self.norm_new_x(new_x)
        return new_x


class MultimodalGroupingNetwork(nn.Module):
    """Stack of MultimodalGroupingBlocks that fuses visual group tokens with word tokens.

    Args:
        dim (int): Dimension of the input.
        out_dim (int): Dimension of the output.
        num_heads (int): Number of heads in the grouping attention.
        norm_layer (nn.Module): Normalization layer to use.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: (0.5, 4.0).
        hard (bool): Whether to use hard or soft assignment. Default: True.
        gumbel (bool): Whether to use gumbel softmax. Default: True.
        sum_assign (bool): Whether to sum assignments or average them. Default: False.
        assign_eps (float): Epsilon to avoid division by zero. Default: 1.
        gumbel_tau (float): Temperature for gumbel softmax. Default: 1.
        attn_drop (float): Attention dropout rate. Default: 0.
        num_layers (int): Number of stacked blocks. Default: 1.
    """

    def __init__(self,
                 *,
                 dim,
                 out_dim,
                 num_heads,
                 norm_layer,
                 mlp_ratio=(0.5, 4.0),
                 hard=True,
                 gumbel=True,
                 sum_assign=False,
                 assign_eps=1.,
                 gumbel_tau=1.,
                 attn_drop=0.,
                 num_layers=1,
                 ):
        super(MultimodalGroupingNetwork, self).__init__()
        self.num_layers = num_layers
        self.blocks = nn.ModuleList([
            MultimodalGroupingBlock(
                dim=dim,
                out_dim=out_dim,
                num_heads=num_heads,
                norm_layer=norm_layer,
                mlp_ratio=mlp_ratio,
                hard=hard,
                gumbel=gumbel,
                sum_assign=sum_assign,
                assign_eps=assign_eps,
                gumbel_tau=gumbel_tau,
                attn_drop=attn_drop,
            ) for i in range(num_layers)
        ])

    def forward(self, visual_tokens, text_tokens, entity_masks=None, question_masks=None, return_attn=False, return_feat=False):
        """Fuse visual and text tokens; each block applies 1. norm, 2. cross-attention, 3. self-attention.

        Args:
            visual_tokens (torch.Tensor): visual group tokens, [B, K, C]
            text_tokens (torch.Tensor): word tokens, [B, L, C]
            entity_masks (torch.Tensor, optional): mask over entity tokens, [B, L]
            question_masks (torch.Tensor, optional): attention mask over question tokens, [B, L]
            return_attn (bool): whether to return the attention map
            return_feat (bool): whether to return all fused tokens instead of the first-token embedding

        Returns:
            answer (torch.Tensor): [B, C] fused answer embedding, or [B, L, C] if return_feat is True
        """
        ans_text = None
        for i, blk in enumerate(self.blocks):
            ans_text = blk(ans_text, visual_tokens, text_tokens, entity_masks, question_masks, return_attn)

        if return_feat is True:  # [B, L, d_t]
            return ans_text

        answer = ans_text[:, 0]
        return answer
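

# Usage sketch (illustrative; the shapes and hyper-parameters below are assumptions):
#   fuser = MultimodalGroupingNetwork(dim=256, out_dim=256, num_heads=8,
#                                     norm_layer=nn.LayerNorm, hard=False, gumbel=False, num_layers=1)
#   groups = torch.randn(2, 8, 256)    # [B, K, C] visual group tokens
#   words = torch.randn(2, 77, 256)    # [B, L, C] question word tokens
#   answer = fuser(groups, words)      # [B, 256], read out at the first token position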


@MODELS.register_module()
class MultiLabelContrastive(nn.Module):

    def __init__(self,
                 img_encoder,
                 text_encoder,
                 output_dim=256,
                 contrast_temperature=0.07,
                 proj_num_layers=2,
                 multi_label=0,
                 share_temperature=False,
                 multi_label_loss_weight=1.0,
                 use_entityloss=False,
                 entity_weight=1.0,
                 cross_layers=1,
                 use_maskloss=False,
                 maskloss_weight=0.1,
                 num_deep_stages=1,
                 cost_type='L2',
                 cross_threshold=0.6,
                 topmask_ratio=1.0,
                 dual_dice=False,
                 group_ratio=0.5,
                 ):
        super().__init__()

        self.img_encoder = MODELS.build(img_encoder)
        self.text_encoder = MODELS.build(text_encoder)
        self.img_encoder_type = img_encoder['type']
        self.text_encoder_type = text_encoder['type']
        # add
        print('self image encoder: ', img_encoder)
        print('self text encoder:', text_encoder)

        self.contrast_temperature = contrast_temperature
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))

        self.cross_entropy = nn.CrossEntropyLoss(ignore_index=-1)
        self.binary_cross_entropy = nn.BCELoss()
        self.binary_cross_entropy_with_logits = nn.BCEWithLogitsLoss()
        self.soft_cross_entropy = SoftTargetCrossEntropy()
        self.mse_loss = nn.MSELoss()

        self.proj_num_layers = proj_num_layers
        self.multi_label = multi_label

        if proj_num_layers > 0:
            # if proj_num_layers > 0 and self.use_clip_visual is False:
            self.img_projector = ProjectMLP(
                in_dim=self.img_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
            self.text_projector = ProjectMLP(
                in_dim=self.text_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
            self.img_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.img_projector)
            self.text_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.text_projector)
        elif proj_num_layers == -1:
            self.img_projector = nn.Linear(self.img_encoder.width, self.text_encoder.width)
            self.text_projector = nn.Identity()
        else:
            self.img_projector = nn.Identity()
            self.text_projector = nn.Identity()

        self.share_temperature = share_temperature
        if self.with_multi_label and not self.share_temperature:
            self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
        self.multi_label_loss_weight = multi_label_loss_weight

        ### for masked entity loss ###
        self.use_entityloss = use_entityloss
        self.entity_weight = entity_weight
        self.cross_layers = cross_layers
        if self.use_entityloss:
            min_width = min(self.img_encoder.width, self.text_encoder.width)
            max_width = max(self.img_encoder.width, self.text_encoder.width)
            self.align_proj_img = nn.Linear(max_width, min_width) if self.img_encoder.width > self.text_encoder.width else nn.Identity()
            self.align_proj_text = nn.Linear(max_width, min_width) if self.text_encoder.width > self.img_encoder.width else nn.Identity()
            ### similar to transformer decoder ###
            self.multimodal_groupingblock = MultimodalGroupingNetwork(
                dim=min_width,
                out_dim=min_width,
                num_heads=8,
                norm_layer=nn.LayerNorm,
                hard=False,
                gumbel=False,
                num_layers=cross_layers,
            )
            self.bridge_projector = ProjectMLP(
                in_dim=min_width, num_layers=proj_num_layers, out_dim=output_dim)

        ### for mask loss ###
        self.use_maskloss = use_maskloss
        self.maskloss_weight = maskloss_weight
        self.cross_threshold = cross_threshold
        self.topmask_ratio = topmask_ratio
        self.dual_dice = dual_dice
        self.group_ratio = group_ratio
        if self.use_maskloss:
            self.num_deep_stages = num_deep_stages
            self.logit_scale_mask = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
            self.img_encoder_momentum = MODELS.build(img_encoder)
            self.q_projector = nn.Identity()
            self.k_projector = nn.Identity()
            self.q_projector_momentum = nn.Identity()
            self.k_projector_momentum = nn.Identity()
            ## set momentum branch offline
            for p in self.img_encoder_momentum.parameters():
                p.requires_grad = False
            self.matcher = HungarianMatcher(cost_type=cost_type)

    def mask_loss(self, mask1, mask2, threshold, imgtokens=None, text=None, indicator='none'):
        # set_trace()
        bs = mask1.size(0)
        num_masks = mask1.size(1)
        ################# hungarian matching #######################################
        # [b, k, hw], make the masks exclusive with softmax???
        ### Note: we keep the original masks and use the normalised masks only to compute the matching ###
        mask1 = torch.flatten(mask1, 2).float()
        mask2 = torch.flatten(mask2, 2).float()
        mask1_norm = F.normalize(mask1, dim=-1)
        mask2_norm = F.normalize(mask2, dim=-1)
        idx1, idx2 = self.matcher(mask1_norm, mask2_norm)
        mask1 = mask1[torch.arange(bs).unsqueeze(1), idx1]
        mask2 = mask2[torch.arange(bs).unsqueeze(1), idx2]

        ################## norm and contrastive loss ################################
        # [b, k, hw]
        ################# BCE loss ##################################################
        ### hard-thresholding ###
        def min_max_norm(x):
            x_max = torch.max(x, dim=-1, keepdim=True)[0]
            x_min = torch.min(x, dim=-1, keepdim=True)[0]
            return (x - x_min) / (x_max - x_min)

        ################ THIS IS PERHAPS IMPORTANT HERE ##############
        mask2 = mask2.sigmoid()
        # mask2 = F.softmax(mask2, dim=1)
        # mask2 = min_max_norm(mask2)
        # mask2 = F.normalize(mask2)

        mask2_pseudo = mask2
        mask2_pseudo = rearrange(mask2_pseudo, 'b k d -> (b k) d')
        thres_onehot = torch.max(mask2_pseudo, dim=-1, keepdim=True)[0] * threshold
        mask2_onehot = mask2_pseudo - thres_onehot
        mask2_onehot[mask2_onehot >= 0] = 1.0
        mask2_onehot[mask2_onehot < 0] = 0.0
        mask2_onehot = rearrange(mask2_onehot, '(b k) d -> b k d', k=num_masks)
        # self.draw_attn(rearrange(mask1, 'b k (h w) -> b k h w', k=num_masks, h=224), 'before_sigmoid')
        # set_trace()
        # mask1 = F.softmax(mask1, dim=1)
        # mask1 = torch.sigmoid(mask1)
        mask1 = min_max_norm(mask1)

        ####### select top-k masks for contrast w.r.t. the ratio #######
        topk_mask = None
        # if self.topmask_ratio < 1.0:
        #     alltoken_logits = (imgtokens @ text.unsqueeze(-1)).squeeze(-1)  # [bs, k]
        #     topk_logits = torch.topk(alltoken_logits, k=int(num_masks * self.topmask_ratio))[1]
        #     topk_mask = torch.zeros_like(alltoken_logits)
        #     topk_mask[torch.arange(bs).unsqueeze(1), topk_logits] = 1.0
        #     set_trace()
        #################################################################

        loss = dice_loss(mask1, mask2_onehot, topk_mask=topk_mask)
        return loss
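
    # For reference, a standard dice formulation (the exact dice_loss used here lives in .losses
    # and may differ in details such as reduction or epsilon):
    #   dice(p, t) = 1 - (2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps)
    # computed per matched mask pair after the Hungarian matching above, then averaged.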

    @property
    def with_multi_label(self):
        return self.multi_label > 0

    def loss(self, image_x, text_x):
        batch_size = image_x.shape[0]
        # get label globally
        labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + batch_size * dist.get_rank()

        image_x = F.normalize(image_x, dim=-1)  # [B, C]
        text_x = F.normalize(text_x, dim=-1)  # [B, C]

        logits_per_img = image_x @ dist_collect(text_x).t()
        logits_per_text = text_x @ dist_collect(image_x).t()

        logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
        loss_img = self.cross_entropy(logits_per_img * logit_scale, labels)
        loss_text = self.cross_entropy(logits_per_text * logit_scale, labels)
        loss = 0.5 * (loss_img + loss_text)
        return loss
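
    # Minimal single-process sketch of the same symmetric InfoNCE objective (illustrative only,
    # assuming world size 1 so that dist_collect is effectively the identity):
    #   sim = F.normalize(image_x, dim=-1) @ F.normalize(text_x, dim=-1).t()   # [B, B]
    #   labels = torch.arange(sim.size(0), device=sim.device)
    #   scale = self.logit_scale.exp().clamp(max=100)
    #   loss = 0.5 * (F.cross_entropy(sim * scale, labels) + F.cross_entropy(sim.t() * scale, labels))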

    def multi_label_loss(self, image_feat, text_feat):
        """
        Args:
            image_feat (torch.Tensor): shape [B, L1, C]
            text_feat (torch.Tensor): shape [B, L2, C]

        Returns:
            torch.Tensor: scalar multi-label contrastive loss
        """
        # [B, L1, C], L1 = 1
        image_feat = F.normalize(image_feat, dim=-1)
        # [B, L2, C]
        text_feat = F.normalize(text_feat, dim=-1)

        # [B, L1, L2]
        dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
        # [B, L2, L1]
        dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')

        if self.share_temperature:
            logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
        else:
            logit_scale = torch.clamp(self.multi_label_logit_scale.exp(), max=100)

        batch = image_feat.shape[0]
        img_len = image_feat.shape[1]
        text_len = text_feat.shape[1]
        # [B, L1, L2]
        pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
        # [B, L2, L1]
        pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')

        image_x = rearrange(image_feat, 'b l c -> (b l) c')
        text_x = rearrange(text_feat, 'b l c -> (b l) c')

        logits_per_img = image_x @ dist_collect(text_x).t()
        logits_per_text = text_x @ dist_collect(image_x).t()

        # get label globally
        # [B, L1, B, L2, W]
        labels_per_img = F.one_hot(
            torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * dist.get_rank(),
            num_classes=dist.get_world_size()).to(image_x.dtype)
        labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
            torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
        # [BxL1, WxBxL2]
        labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')

        # [B, L2, B, L1, W]
        labels_per_text = F.one_hot(
            torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * dist.get_rank(),
            num_classes=dist.get_world_size()).to(text_x.dtype)
        labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
            torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
        # [BxL2, WxBxL1]
        labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')

        loss_img = self.soft_cross_entropy(logits_per_img * logit_scale, labels_per_img)
        loss_text = self.soft_cross_entropy(logits_per_text * logit_scale, labels_per_text)

        loss = 0.5 * (loss_img + loss_text)
        return loss
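
    # Added note (intuition, not original commentary): each image is paired with several prompt
    # texts derived from its own caption, so the soft target spreads probability mass uniformly
    # over those L2 positive text entries (and over the L1 image entries in the text-to-image
    # direction), while entries from other batch elements and other ranks act as negatives.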

    def encode_image(self, image, *, return_feat=False, as_dict=False, return_attn=False, momentum=False):
        outs = Result(as_dict)
        ### momentum branch, no gradient update ###
        if momentum:
            with torch.no_grad():
                img_outs = self.img_encoder_momentum(image, return_feat=return_feat, as_dict=True, return_attn=return_attn)
                outs.append(self.img_projector(img_outs['x']), 'image_x')
                if return_feat and 'feat' in img_outs:
                    outs.append(img_outs['x'], 'image_x_before_proj')
                    outs.append(img_outs['feat'], 'image_feat_before_proj')
                if return_feat:
                    outs.append(self.img_projector(img_outs['feat']), 'image_feat')
                if return_attn:
                    outs.append(img_outs['attn_dicts'], 'attn_dicts')
                return outs.as_return()
        else:
            ### online branch ###
            img_outs = self.img_encoder(image, return_feat=return_feat, as_dict=True, return_attn=return_attn)
            # change here
            outs.append(self.img_projector(img_outs['x']), 'image_x')
            if return_feat and 'feat' in img_outs:
                outs.append(img_outs['x'], 'image_x_before_proj')
                outs.append(img_outs['feat'], 'image_feat_before_proj')
            if return_feat:
                outs.append(self.img_projector(img_outs['feat']), 'image_feat')
            if return_attn:
                outs.append(img_outs['attn_dicts'], 'attn_dicts')
            return outs.as_return()

    def encode_text(self, text, *, as_dict=False, forward_template=False):
        # assert text.ndim in [2, 3], text.ndim
        squeeze_dim = False
        num_text = 1
        if type(text) is not dict and text.ndim == 3:
            num_text = text.shape[1]
            text = rearrange(text, 'b n l -> (b n) l', n=num_text)
            squeeze_dim = True

        outs = Result(as_dict=as_dict)
        # [B, C]
        text_outs = self.text_encoder(text)
        if 'all_tokens' in text_outs:
            all_tokens = text_outs['all_tokens'].contiguous()
        x = text_outs['x']

        text_x = self.text_projector(x)
        outs.append(text_x, 'text_x')
        outs.append(x, 'text_x_before_proj')  # add transformer out
        outs.append(all_tokens, 'text_feat_before_proj')
        outs.append(self.text_projector(all_tokens), 'text_feat_after_proj')

        # if squeeze_dim:
        if (squeeze_dim or self.with_multi_label) and self.training:
            # text_x = rearrange(text_x, '(b n) c -> b n c', n=num_text)
            text_x = rearrange(text_x, '(b n) c -> b n c', n=self.multi_label + 1)  ### 2 prompts and 1 caption
            text_multi_label_x = text_x[:, 1:]
            text_x = text_x[:, 0]
            ####### here projection !!!! #######
            outs.update(text_x=text_x, text_multi_label_x=text_multi_label_x)

        return outs.as_return()

    def project_and_mask(self, q, k, branch='online'):
        scale = self.img_encoder.width ** -0.5
        if branch == 'online':
            q = self.q_projector(q)
            k = self.k_projector(k)
            attn = q @ k.transpose(-2, -1) * scale  ### no softmax for now
        else:
            with torch.no_grad():
                q = self.q_projector_momentum(q)
                k = self.k_projector_momentum(k)
                attn = q @ k.transpose(-2, -1) * scale  ### no softmax for now
        return attn

    def forward_train(self, image, text, cross_image=None, cross_entity=None,
                      question=None, answer=None, entity_masks=None, question_masks=None):
        bs = image.size(0)
        losses_dict = dict()

        ############################################################
        ### Encode image and caption, calculate image-caption matching loss ###
        text_outs = self.encode_text(text, as_dict=True)
        text_x = text_outs['text_x']  # [B, C]
        image_outs = self.encode_image(image, as_dict=True, return_feat=True, return_attn=True)
        image_x = image_outs['image_x']  # [B, C]
        matchingloss = self.loss(image_x, text_x)
        losses_dict['matching'] = matchingloss

        ############################################################
        ### Encode question/answer and calculate masked entity modeling loss (if necessary) ###
        entityloss = .0
        if self.use_entityloss:
            visual_feat = image_outs['image_feat_before_proj']  # unprojected group token features, [b, k, d_v]
            ### Encode questions ###
            question_feat = self.encode_text(question, as_dict=True)['text_feat_before_proj']  # unprojected word tokens, [B, L, d_t]
            current_question_masks = question['attention_mask'] if isinstance(question, dict) else None
            ### Encode answer ###
            answer_feat = self.encode_text(answer, as_dict=True)['text_x']  # projected answer embedding, [B, d]
            ### project the group features/question features to the common multimodal space ###
            visual_feat = self.align_proj_img(visual_feat)
            question_feat = self.align_proj_text(question_feat)
            ### calculate entity loss ###
            question_out = self.multimodal_groupingblock(visual_feat, question_feat, entity_masks=entity_masks, question_masks=current_question_masks)  # [b, d_t]
            question_out = self.bridge_projector(question_out)  # [b, d]
            entityloss = self.loss(question_out, answer_feat)
            losses_dict['entity'] = entityloss

        ############################################################
        ### Encode cross-image and calculate mask loss ###
        maskloss = .0
        if self.use_maskloss:
            assert cross_image is not None and cross_entity is not None
            image_outs3 = self.encode_image(cross_image, as_dict=True, return_feat=True, return_attn=True, momentum=True)
            # total_stages = len(image_outs3['attn_dicts'])
            attn_q = image_outs['attn_dicts'][0]['q'].squeeze(1)
            attn_k = image_outs['attn_dicts'][0]['k'].squeeze(1)
            attn_q_cross = image_outs3['attn_dicts'][0]['q'].squeeze(1)
            attn_k_cross = image_outs3['attn_dicts'][0]['k'].squeeze(1)
            attn_map3 = self.project_and_mask(attn_q_cross, attn_k_cross)
            attn_map_cross1 = self.project_and_mask(attn_q, attn_k_cross)  # the mask to match the image

            def compute_cross_loss(mask1, mask2, cross_entity, groups, indicator='none'):
                mask1 = rearrange(mask1, 'b k (h w) -> b k h w', h=14, w=14)  # hard-coded for now, [b, k, h, w]
                mask2 = rearrange(mask2, 'b k (h w) -> b k h w', h=14, w=14)  # hard-coded for now, [b, k, h, w]
                mask1 = F.interpolate(mask1, size=(224, 224), mode='bilinear', align_corners=True)
                mask2 = F.interpolate(mask2, size=(224, 224), mode='bilinear', align_corners=True)
                ###### get the representation of the sampled noun and measure the similarity ######
                if cross_entity is not None:
                    with torch.no_grad():
                        noun_feat = self.encode_text(cross_entity, as_dict=True)['text_x']  # [bs, d_c]
                    group_logits = (groups @ noun_feat.unsqueeze(-1)).squeeze(-1)  # [bs, k]
                    num_groups = group_logits.size(1)
                    topk_logits = torch.topk(group_logits, k=int(num_groups * self.group_ratio), largest=False)[1]
                    mask1[torch.arange(bs).unsqueeze(1), topk_logits] = mask1[torch.arange(bs).unsqueeze(1), topk_logits].detach()
                ###################################################################################
                return self.mask_loss(mask1, mask2.detach(), self.cross_threshold, indicator=indicator)

            maskloss_cross = compute_cross_loss(attn_map_cross1, attn_map3, cross_entity, image_outs['image_feat'], indicator='none')
            if self.dual_dice:
                dual_image_outs = self.encode_image(image, as_dict=True, return_feat=True, return_attn=True, momentum=True)
                dual_image_outs3 = self.encode_image(cross_image, as_dict=True, return_feat=True, return_attn=True)
                dual_attn_q = dual_image_outs['attn_dicts'][0]['q'].squeeze(1)
                dual_attn_k = dual_image_outs['attn_dicts'][0]['k'].squeeze(1)
                dual_attn_q_cross = dual_image_outs3['attn_dicts'][0]['q'].squeeze(1)
                dual_attn_k_cross = dual_image_outs3['attn_dicts'][0]['k'].squeeze(1)
                dual_attn_map = self.project_and_mask(dual_attn_q, dual_attn_k)
                dual_attn_map_cross = self.project_and_mask(dual_attn_q_cross, dual_attn_k)
                dual_maskloss = compute_cross_loss(dual_attn_map_cross, dual_attn_map, cross_entity, dual_image_outs3['image_feat'], indicator='cross')
                maskloss_cross = (maskloss_cross + dual_maskloss) * 0.5

            maskloss = maskloss_cross
            losses_dict['mask'] = maskloss

        ############################################################
        ### total loss ###
        if self.use_entityloss and self.use_maskloss:  ### for 2nd stage ###
            losses = matchingloss + self.entity_weight * entityloss + self.maskloss_weight * maskloss
        elif self.use_entityloss:  ### for 1st stage ###
            losses = matchingloss + self.entity_weight * entityloss
        else:  ### baseline ###
            losses = matchingloss

        if self.with_multi_label:
            image_multi_label_x = image_x.unsqueeze(1)
            text_multi_label_x = text_outs['text_multi_label_x']
            loss_multi_label = self.multi_label_loss(image_multi_label_x, text_multi_label_x) * self.multi_label_loss_weight
            losses_dict['multi_label'] = loss_multi_label
            losses += loss_multi_label

        losses_dict['loss'] = losses
        return losses_dict

    def forward_test(self, image, text):
        return self.zero_shot_pred(image, text)

    def forward(self, image, text, cross_image=None, cross_entity=None,
                question=None, answer=None, entity_masks=None, question_masks=None):
        """
        Args:
            image: [B, 3, 224, 224], raw input image
            text: [B, L], tokenised caption of length L
            cross_image: [B, 3, 224, 224], an image that shares an entity with the input image
            cross_entity: [B, L], tokenised text of the shared entity
            question: [B, L], tokenised question
            answer: [B, L], tokenised prompted answer
            entity_masks: [B, L]
            question_masks: [B, L]
        """
        if self.training:
            return self.forward_train(image=image, text=text, cross_image=cross_image, cross_entity=cross_entity,
                                      question=question, answer=answer, entity_masks=entity_masks, question_masks=question_masks)
        else:
            return self.forward_test(image, text)

    @torch.no_grad()
    def build_text_embedding(self, text, tokenizer=None, num_classes=20):
        """
        Args:
            text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH];
                for DistilBert/Bert-style encoders, text is instead a flat list of
                num_classes * num_templates prompt strings
            num_classes (int): 20 for VOC by default, 1000 for IN1K (currently unused)

        Returns:
            text_tokens (torch.Tensor): [NUM_CLASSES, C] class embeddings averaged over templates
        """
        if self.text_encoder_type in ['DistilBert', 'Bert', 'BertMedium', 'Roberta']:
            assert tokenizer is not None
            text_data = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
            text_data = {key: val.cuda() for key, val in text_data.items()}
            text_tokens = self.encode_text(text_data, as_dict=True, forward_template=True)['text_x']
        else:
            text = text.to(next(self.parameters()).device)
            num_classes, num_templates = text.shape[:2]
            text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
            text_tokens = self.encode_text(text, as_dict=True, forward_template=True)['text_x']
        # [N, T, C]
        # text_tokens = rearrange(text_tokens, '(n t) c -> n t c', n=num_classes, t=num_templates)
        text_tokens = rearrange(text_tokens, '(n t) c -> n t c', n=num_classes)
        # [N, C]
        text_tokens = text_tokens.mean(dim=1)
        text_tokens = F.normalize(text_tokens, dim=-1)
        return text_tokens
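
    # Usage sketch for the Bert-style branch (illustrative; the class names, templates, and
    # tokenizer are assumptions, not values from this repository):
    #   classnames = ['dog', 'cat']
    #   templates = ['a photo of a {}.', 'a picture of a {}.']
    #   prompts = [t.format(c) for c in classnames for t in templates]   # class-major order
    #   class_embeds = model.build_text_embedding(prompts, tokenizer=tokenizer,
    #                                             num_classes=len(classnames))  # [2, C]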

    @torch.no_grad()
    def build_text_embedding_without_projection(self, text):
        """
        Args:
            text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH]

        Returns:
            text_tokens (torch.Tensor): [NUM_CLASSES, C] class embeddings before projection
        """
        text = text.to(next(self.parameters()).device)
        num_classes, num_templates = text.shape[:2]
        text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
        text_tokens = self.encode_text(text, as_dict=True, forward_template=True)['text_x_before_proj']
        # [N, T, C]
        text_tokens = rearrange(text_tokens, '(n t) c -> n t c', n=num_classes, t=num_templates)
        # [N, C]
        text_tokens = text_tokens.mean(dim=1)
        text_tokens = F.normalize(text_tokens, dim=-1)
        return text_tokens

    @torch.no_grad()
    def zero_shot_pred(self, image, text):
        # [B, C]
        image_features = self.encode_image(image)
        image_features = F.normalize(image_features, dim=-1)
        # cosine similarity as logits
        logits_per_image = image_features @ text.t()
        return logits_per_image
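
    # Zero-shot classification sketch (illustrative; `model`, `images`, `prompts`, `tokenizer`,
    # and N are assumed to come from the surrounding evaluation code):
    #   class_embeds = model.build_text_embedding(prompts, tokenizer=tokenizer, num_classes=N)
    #   logits = model.zero_shot_pred(images, class_embeds)   # [B, N]
    #   preds = logits.argmax(dim=-1)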