# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
import diffdist.functional as diff_dist
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from timm.loss import SoftTargetCrossEntropy

from .builder import MODELS
from .misc import Result


def dist_collect(x):
    """Collect a tensor from all GPUs, concatenated along the batch dimension.

    Args:
        x: shape (mini_batch, ...)

    Returns:
        Tensor of shape (mini_batch * num_gpus, ...), the per-GPU
        mini-batches concatenated in rank order.
    """
    x = x.contiguous()
    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())]
    out_list = diff_dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0).contiguous()
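

# Note: diffdist's all_gather is used here instead of torch.distributed.all_gather
# because the latter does not backpropagate gradients to the input tensor;
# diffdist re-wires the gather so each rank still receives gradients for its own
# shard, which a contrastive loss over globally gathered features requires.
# A minimal sketch (assumes torch.distributed is already initialized):
#
#   feats = encoder(images)            # [B, C] on this rank
#   all_feats = dist_collect(feats)    # [world_size * B, C], differentiable
#   logits = feats @ all_feats.t()     # local-vs-global similarity matrix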


class ProjectMLP(nn.Module):

    def __init__(self, in_dim=256, inner_dim=4096, out_dim=256, num_layers=2):
        super(ProjectMLP, self).__init__()
        # hidden layers
        linear_hidden = []
        for i in range(num_layers - 1):
            linear_hidden.append(nn.Conv1d(in_dim if i == 0 else inner_dim, inner_dim, kernel_size=1))
            linear_hidden.append(nn.BatchNorm1d(inner_dim))
            linear_hidden.append(nn.ReLU(inplace=True))
        self.linear_hidden = nn.Sequential(*linear_hidden)
        # a 1x1 convolution acts as a per-token linear layer; with
        # num_layers == 0 the projector degenerates to an identity mapping
        self.linear_out = nn.Conv1d(
            in_dim if num_layers == 1 else inner_dim, out_dim, kernel_size=1) if num_layers >= 1 else nn.Identity()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): output of transformers, shape [B, L, C] or [B, C]

        Returns:
            torch.Tensor: projected features, shape [B, L, out_dim]
            (or [B, out_dim] when the input was 2D)
        """
        assert x.ndim in [2, 3], x.ndim
        add_dim = False
        if x.ndim == 2:
            # [B, C] -> [B, 1, C]
            x = x.unsqueeze(1)
            add_dim = True

        x = rearrange(x, 'b l c -> b c l')
        x = self.linear_hidden(x)
        x = self.linear_out(x)
        x = rearrange(x, 'b c l -> b l c')

        if add_dim:
            x = x.squeeze(1)

        return x
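

# A minimal shape check for ProjectMLP (hypothetical values; eval() avoids the
# batch-statistics requirement of BatchNorm1d during the check):
#
#   proj = ProjectMLP(in_dim=256, inner_dim=4096, out_dim=256, num_layers=2).eval()
#   assert proj(torch.randn(4, 256)).shape == (4, 256)        # [B, C] -> [B, C']
#   assert proj(torch.randn(4, 7, 256)).shape == (4, 7, 256)  # [B, L, C] -> [B, L, C']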


@MODELS.register_module()
class MultiLabelContrastive(nn.Module):

    def __init__(self,
                 img_encoder,
                 text_encoder,
                 output_dim=256,
                 contrast_temperature=0.07,
                 proj_num_layers=2,
                 multi_label=0,
                 share_temperature=False,
                 multi_label_loss_weight=1.0):
        super().__init__()

        self.img_encoder = MODELS.build(img_encoder)
        self.text_encoder = MODELS.build(text_encoder)

        self.contrast_temperature = contrast_temperature
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
        self.cross_entropy = nn.CrossEntropyLoss()
        self.soft_cross_entropy = SoftTargetCrossEntropy()

        self.proj_num_layers = proj_num_layers
        self.multi_label = multi_label
        if proj_num_layers > 0:
            self.img_projector = ProjectMLP(
                in_dim=self.img_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
            self.text_projector = ProjectMLP(
                in_dim=self.text_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
            self.img_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.img_projector)
            self.text_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.text_projector)
        else:
            self.img_projector = nn.Identity()
            self.text_projector = nn.Identity()

        self.share_temperature = share_temperature
        if self.with_multi_label and not self.share_temperature:
            self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
        self.multi_label_loss_weight = multi_label_loss_weight
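
    # The learnable logit_scale follows CLIP's initialization: it starts at
    # log(1 / 0.07) ≈ 2.659, so exp(logit_scale) ≈ 14.29 at the beginning of
    # training, and the losses below clamp its exponential to 100, i.e. the
    # effective temperature never drops below 0.01.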

    @property
    def with_multi_label(self):
        return self.multi_label > 0

    def loss(self, image_x, text_x):
        batch_size = image_x.shape[0]
        # get label globally
        labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + batch_size * dist.get_rank()

        # [B, C]
        image_x = F.normalize(image_x, dim=-1)
        text_x = F.normalize(text_x, dim=-1)

        logits_per_img = image_x @ dist_collect(text_x).t()
        logits_per_text = text_x @ dist_collect(image_x).t()

        logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
        loss_img = self.cross_entropy(logits_per_img * logit_scale, labels)
        loss_text = self.cross_entropy(logits_per_text * logit_scale, labels)

        loss = 0.5 * (loss_img + loss_text)

        return loss
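
    # Worked example for the global labels in `loss`: with world_size 2 and a
    # per-GPU batch of 4, dist_collect returns 8 rows ordered by rank, so the
    # positives for rank 0 occupy columns 0..3 and those for rank 1 columns
    # 4..7; `arange(4) + 4 * rank` yields exactly those indices.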

    def multi_label_loss(self, image_feat, text_feat):
        """
        Args:
            image_feat (torch.Tensor): shape [B, L1, C]
            text_feat (torch.Tensor): shape [B, L2, C]

        Returns:
            torch.Tensor: scalar multi-label contrastive loss
        """
        # [B, L1, C], L1 = 1
        image_feat = F.normalize(image_feat, dim=-1)
        # [B, L2, C]
        text_feat = F.normalize(text_feat, dim=-1)

        # [B, L1, L2]
        dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
        # [B, L2, L1]
        dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')

        if self.share_temperature:
            logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
        else:
            logit_scale = torch.clamp(self.multi_label_logit_scale.exp(), max=100)

        batch = image_feat.shape[0]
        img_len = image_feat.shape[1]
        text_len = text_feat.shape[1]
        # [B, L1, L2]
        pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
        # [B, L2, L1]
        pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')

        image_x = rearrange(image_feat, 'b l c -> (b l) c')
        text_x = rearrange(text_feat, 'b l c -> (b l) c')

        logits_per_img = image_x @ dist_collect(text_x).t()
        logits_per_text = text_x @ dist_collect(image_x).t()

        # get label globally
        # [B, L1, B, L2, W]
        labels_per_img = F.one_hot(
            torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * dist.get_rank(),
            num_classes=dist.get_world_size()).to(image_x.dtype)
        labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
            torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
        # [BxL1, WxBxL2]
        labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
        # [B, L2, B, L1, W]
        labels_per_text = F.one_hot(
            torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * dist.get_rank(),
            num_classes=dist.get_world_size()).to(text_x.dtype)
        labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
            torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
        # [BxL2, WxBxL1]
        labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')

        loss_img = self.soft_cross_entropy(logits_per_img * logit_scale, labels_per_img)
        loss_text = self.soft_cross_entropy(logits_per_text * logit_scale, labels_per_text)

        loss = 0.5 * (loss_img + loss_text)

        return loss
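
    # How the soft labels above are assembled: for each local row (one image
    # token or one text label), the target over the W * B * L gathered columns
    # is nonzero only on this rank's block (the one-hot over world size) and
    # only for the same sample within the batch (the identity matrix over B);
    # the 1 / L2 (resp. 1 / L1) factor spreads probability mass evenly over the
    # multiple positive text labels of one image, which is why the soft-target
    # cross entropy is used here rather than hard-label CE.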

    def encode_image(self, image, *, return_feat=False, as_dict=False):
        outs = Result(as_dict)
        img_outs = self.img_encoder(image, return_feat=return_feat, as_dict=True)
        outs.append(self.img_projector(img_outs['x']), 'image_x')
        if return_feat:
            outs.append(self.img_projector(img_outs['feat']), 'image_feat')
        return outs.as_return()

    def encode_text(self, text, *, as_dict=False):
        assert text.ndim in [2, 3], text.ndim
        squeeze_dim = False
        num_text = 1
        if text.ndim == 3:
            num_text = text.shape[1]
            text = rearrange(text, 'b n l -> (b n) l', n=num_text)
            squeeze_dim = True

        outs = Result(as_dict=as_dict)
        # [B, C]
        x = self.text_encoder(text)
        text_x = self.text_projector(x)
        outs.append(text_x, 'text_x')

        if squeeze_dim:
            text_x = rearrange(text_x, '(b n) c -> b n c', n=num_text)
            # entry 0 is the caption embedding; the remaining entries are the
            # additional multi-label text embeddings that accompany it
            text_multi_label_x = text_x[:, 1:]
            text_x = text_x[:, 0]
            outs.update(text_x=text_x, text_multi_label_x=text_multi_label_x)

        return outs.as_return()
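
    # A note on the 3-D text input (an assumption about the data pipeline, per
    # GroupViT's training setup: each caption is tokenized together with
    # `multi_label` extra prompts built from sampled nouns): the input is then
    # [B, 1 + multi_label, context_length], and encode_text splits the caption
    # embedding from the multi-label embeddings as above.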

    def forward_train(self, image, text):
        image_outs = self.encode_image(image, as_dict=True)
        # [B, C]
        image_x = image_outs['image_x']

        text_outs = self.encode_text(text, as_dict=True)
        # [B, C]
        text_x = text_outs['text_x']

        losses = self.loss(image_x, text_x)
        losses_dict = dict(loss=losses)

        if self.with_multi_label:
            image_multi_label_x = image_x.unsqueeze(1)
            text_multi_label_x = text_outs['text_multi_label_x']
            losses_dict['multi_label_loss'] = self.multi_label_loss(
                image_multi_label_x, text_multi_label_x) * self.multi_label_loss_weight

        return losses_dict

    def forward_test(self, image, text):
        return self.zero_shot_pred(image, text)

    def forward(self, image, text):
        if self.training:
            return self.forward_train(image, text)
        else:
            return self.forward_test(image, text)

    @torch.no_grad()
    def build_text_embedding(self, text):
        """
        Args:
            text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH]

        Returns:
            torch.Tensor: [NUM_CLASSES, C], one L2-normalized embedding per
            class, averaged over the prompt templates
        """
        text = text.to(next(self.parameters()).device)
        num_classes, num_templates = text.shape[:2]
        text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
        text_tokens = self.encode_text(text)
        # [N, T, C]
        text_tokens = rearrange(text_tokens, '(n t) c -> n t c', n=num_classes, t=num_templates)
        # [N, C]
        text_tokens = text_tokens.mean(dim=1)
        text_tokens = F.normalize(text_tokens, dim=-1)

        return text_tokens
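
    # A hedged usage sketch for build_text_embedding (`tokenize`, the template
    # strings and `class_names` are hypothetical stand-ins for any CLIP-style
    # tokenization producing [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH]):
    #
    #   templates = ['a photo of a {}.', 'a picture of a {}.']
    #   tokens = torch.stack([
    #       torch.stack([tokenize(t.format(name)) for t in templates])
    #       for name in class_names])                         # [N, T, L]
    #   class_embedding = model.build_text_embedding(tokens)  # [N, C]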

    @torch.no_grad()
    def zero_shot_pred(self, image, text):
        # [B, C]
        image_features = self.encode_image(image)
        image_features = F.normalize(image_features, dim=-1)

        # cosine similarity as logits; `text` is expected to be the
        # pre-normalized class embedding matrix from build_text_embedding
        logits_per_image = image_features @ text.t()

        return logits_per_image
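

# A hedged end-to-end sketch of zero-shot classification with this module (the
# encoder config dicts are hypothetical and must name models registered in
# MODELS; distributed initialization is only required for the training path):
#
#   model = MODELS.build(dict(type='MultiLabelContrastive',
#                             img_encoder=img_cfg, text_encoder=text_cfg)).eval()
#   class_embedding = model.build_text_embedding(tokenized_prompts)  # [N, C]
#   logits = model.zero_shot_pred(images, class_embedding)           # [B, N]
#   pred = logits.argmax(dim=-1)                                     # [B]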