# File: multi_label_contrastive.py (11 KB)

  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # -------------------------------------------------------------------------
  13. import diffdist.functional as diff_dist
  14. import numpy as np
  15. import torch
  16. import torch.distributed as dist
  17. import torch.nn as nn
  18. import torch.nn.functional as F
  19. from einops import rearrange, repeat
  20. from timm.loss import SoftTargetCrossEntropy
  21. from .builder import MODELS
  22. from .misc import Result
  23. def dist_collect(x):
  24. """ collect all tensor from all GPUs
  25. args:
  26. x: shape (mini_batch, ...)
  27. returns:
  28. shape (mini_batch * num_gpu, ...)
  29. """
  30. x = x.contiguous()
  31. out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())]
  32. out_list = diff_dist.all_gather(out_list, x)
  33. return torch.cat(out_list, dim=0).contiguous()
  34. class ProjectMLP(nn.Module):
  35. def __init__(self, in_dim=256, inner_dim=4096, out_dim=256, num_layers=2):
  36. super(ProjectMLP, self).__init__()
  37. # hidden layers
  38. linear_hidden = []
  39. for i in range(num_layers - 1):
  40. linear_hidden.append(nn.Conv1d(in_dim if i == 0 else inner_dim, inner_dim, kernel_size=1))
  41. linear_hidden.append(nn.BatchNorm1d(inner_dim))
  42. linear_hidden.append(nn.ReLU(inplace=True))
  43. self.linear_hidden = nn.Sequential(*linear_hidden)
  44. self.linear_out = nn.Conv1d(
  45. in_dim if num_layers == 1 else inner_dim, out_dim, kernel_size=1) if num_layers >= 1 else nn.Identity()
  46. def forward(self, x):
  47. """
  48. Args:
  49. x (torch.Tensor): output of transformers, shape [B, L, C]
  50. Returns:
  51. """
  52. assert x.ndim in [2, 3], x.ndim
  53. add_dim = False
  54. if x.ndim == 2:
  55. # [B, C] -> [B, L, C]
  56. x = x.unsqueeze(1)
  57. add_dim = True
  58. x = rearrange(x, 'b l c -> b c l')
  59. x = self.linear_hidden(x)
  60. x = self.linear_out(x)
  61. x = rearrange(x, 'b c l -> b l c')
  62. if add_dim:
  63. x = x.squeeze(1)
  64. return x
  65. @MODELS.register_module()
  66. class MultiLabelContrastive(nn.Module):
  67. def __init__(self,
  68. img_encoder,
  69. text_encoder,
  70. output_dim=256,
  71. contrast_temperature=0.07,
  72. proj_num_layers=2,
  73. multi_label=0,
  74. share_temperature=False,
  75. multi_label_loss_weight=1.0):
  76. super().__init__()
  77. self.img_encoder = MODELS.build(img_encoder)
  78. self.text_encoder = MODELS.build(text_encoder)
  79. self.contrast_temperature = contrast_temperature
  80. self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
  81. self.cross_entropy = nn.CrossEntropyLoss()
  82. self.soft_cross_entropy = SoftTargetCrossEntropy()
  83. self.proj_num_layers = proj_num_layers
  84. self.multi_label = multi_label
  85. if proj_num_layers > 0:
  86. self.img_projector = ProjectMLP(
  87. in_dim=self.img_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
  88. self.text_projector = ProjectMLP(
  89. in_dim=self.text_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
  90. self.img_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.img_projector)
  91. self.text_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.text_projector)
  92. else:
  93. self.img_projector = nn.Identity()
  94. self.text_projector = nn.Identity()
  95. self.share_temperature = share_temperature
  96. if self.with_multi_label and not self.share_temperature:
  97. self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
  98. self.multi_label_loss_weight = multi_label_loss_weight
  99. @property
  100. def with_multi_label(self):
  101. return self.multi_label > 0
  102. def loss(self, image_x, text_x):
  103. batch_size = image_x.shape[0]
  104. # get label globally
  105. labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + batch_size * dist.get_rank()
  106. # [B, C]
  107. image_x = F.normalize(image_x, dim=-1)
  108. text_x = F.normalize(text_x, dim=-1)
  109. logits_per_img = image_x @ dist_collect(text_x).t()
  110. logits_per_text = text_x @ dist_collect(image_x).t()
  111. logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
  112. loss_img = self.cross_entropy(logits_per_img * logit_scale, labels)
  113. loss_text = self.cross_entropy(logits_per_text * logit_scale, labels)
  114. loss = 0.5 * (loss_img + loss_text)
  115. return loss
  116. def multi_label_loss(self, image_feat, text_feat):
  117. """
  118. Args:
  119. image_feat (torch.Tensor): shape [B, L1, C]
  120. text_feat (torch.Tensor): shape [B, L2, C]
  121. Returns:
  122. """
  123. # [B, L1, C], L1 = 1
  124. image_feat = F.normalize(image_feat, dim=-1)
  125. # [B, L2, C]
  126. text_feat = F.normalize(text_feat, dim=-1)
  127. # [B, L1, L2]
  128. dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
  129. # [B, L2, L1]
  130. dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')
  131. if self.share_temperature:
  132. logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
  133. else:
  134. logit_scale = torch.clamp(self.multi_label_logit_scale.exp(), max=100)
  135. batch = image_feat.shape[0]
  136. img_len = image_feat.shape[1]
  137. text_len = text_feat.shape[1]
  138. # [B, L1, L2]
  139. pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
  140. # [B, L2, L1]
  141. pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')
  142. image_x = rearrange(image_feat, 'b l c -> (b l) c')
  143. text_x = rearrange(text_feat, 'b l c -> (b l) c')
  144. logits_per_img = image_x @ dist_collect(text_x).t()
  145. logits_per_text = text_x @ dist_collect(image_x).t()
  146. # get label globally
  147. # [B, L1, B, L2, W]
  148. labels_per_img = F.one_hot(
  149. torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * dist.get_rank(),
  150. num_classes=dist.get_world_size()).to(image_x.dtype)
  151. labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
  152. torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
  153. # [BxL1, WxBxL2]
  154. labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
  155. # [B, L2, B, L1, W]
  156. labels_per_text = F.one_hot(
  157. torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * dist.get_rank(),
  158. num_classes=dist.get_world_size()).to(text_x.dtype)
  159. labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
  160. torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
  161. # [BxL2, WxBxL1]
  162. labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')
  163. loss_img = self.soft_cross_entropy(logits_per_img * logit_scale, labels_per_img)
  164. loss_text = self.soft_cross_entropy(logits_per_text * logit_scale, labels_per_text)
  165. loss = 0.5 * (loss_img + loss_text)
  166. return loss
  167. def encode_image(self, image, *, return_feat=False, as_dict=False):
  168. outs = Result(as_dict)
  169. img_outs = self.img_encoder(image, return_feat=return_feat, as_dict=True)
  170. outs.append(self.img_projector(img_outs['x']), 'image_x')
  171. if return_feat:
  172. outs.append(self.img_projector(img_outs['feat']), 'image_feat')
  173. return outs.as_return()
  174. def encode_text(self, text, *, as_dict=False):
  175. assert text.ndim in [2, 3], text.ndim
  176. squeeze_dim = False
  177. num_text = 1
  178. if text.ndim == 3:
  179. num_text = text.shape[1]
  180. text = rearrange(text, 'b n l -> (b n) l', n=num_text)
  181. squeeze_dim = True
  182. outs = Result(as_dict=as_dict)
  183. # [B, C]
  184. x = self.text_encoder(text)
  185. text_x = self.text_projector(x)
  186. outs.append(text_x, 'text_x')
  187. if squeeze_dim:
  188. text_x = rearrange(text_x, '(b n) c -> b n c', n=num_text)
  189. text_multi_label_x = text_x[:, 1:]
  190. text_x = text_x[:, 0]
  191. outs.update(text_x=text_x, text_multi_label_x=text_multi_label_x)
  192. return outs.as_return()
  193. def forward_train(self, image, text):
  194. image_outs = self.encode_image(image, as_dict=True)
  195. # [B, C]
  196. image_x = image_outs['image_x']
  197. text_outs = self.encode_text(text, as_dict=True)
  198. # [B, C]
  199. text_x = text_outs['text_x']
  200. losses = self.loss(image_x, text_x)
  201. losses_dict = dict(loss=losses)
  202. if self.with_multi_label:
  203. image_multi_label_x = image_x.unsqueeze(1)
  204. text_multi_label_x = text_outs['text_multi_label_x']
  205. losses_dict['multi_label_loss'] = self.multi_label_loss(image_multi_label_x,
  206. text_multi_label_x) * self.multi_label_loss_weight
  207. return losses_dict
  208. def forward_test(self, image, text):
  209. return self.zero_shot_pred(image, text)
  210. def forward(self, image, text):
  211. if self.training:
  212. return self.forward_train(image, text)
  213. else:
  214. return self.forward_test(image, text)
  215. @torch.no_grad()
  216. def build_text_embedding(self, text):
  217. """
  218. Args:
  219. text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH]
  220. Returns:
  221. """
  222. text = text.to(next(self.parameters()).device)
  223. num_classes, num_templates = text.shape[:2]
  224. text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
  225. text_tokens = self.encode_text(text)
  226. # [N, T, C]
  227. text_tokens = rearrange(text_tokens, '(n t) c -> n t c', n=num_classes, t=num_templates)
  228. # [N, C]
  229. text_tokens = text_tokens.mean(dim=1)
  230. text_tokens = F.normalize(text_tokens, dim=-1)
  231. return text_tokens
  232. @torch.no_grad()
  233. def zero_shot_pred(self, image, text):
  234. # [B, C]
  235. image_features = self.encode_image(image)
  236. image_features = F.normalize(image_features, dim=-1)
  237. # cosine similarity as logits
  238. logits_per_image = image_features @ text.t()
  239. return logits_per_image