# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
import diffdist.functional as diff_dist
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from timm.loss import SoftTargetCrossEntropy

from .builder import MODELS
from .misc import Result


def dist_collect(x):
    """Gather a tensor from all GPUs.

    Args:
        x: tensor of shape (mini_batch, ...)

    Returns:
        tensor of shape (mini_batch * num_gpus, ...)
    """
    x = x.contiguous()
    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())]
    out_list = diff_dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0).contiguous()
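
# diffdist's all_gather is used instead of torch.distributed.all_gather
# because it is autograd-aware: gradients flow back through the gathered
# tensors to every rank, which the contrastive losses below rely on.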


class ProjectMLP(nn.Module):

    def __init__(self, in_dim=256, inner_dim=4096, out_dim=256, num_layers=2):
        super().__init__()

        # hidden layers: 1x1 convolutions over the token dimension,
        # each followed by BatchNorm and ReLU
        linear_hidden = []
        for i in range(num_layers - 1):
            linear_hidden.append(nn.Conv1d(in_dim if i == 0 else inner_dim, inner_dim, kernel_size=1))
            linear_hidden.append(nn.BatchNorm1d(inner_dim))
            linear_hidden.append(nn.ReLU(inplace=True))
        self.linear_hidden = nn.Sequential(*linear_hidden)

        # final projection; identity when num_layers == 0
        self.linear_out = nn.Conv1d(
            in_dim if num_layers == 1 else inner_dim, out_dim, kernel_size=1) if num_layers >= 1 else nn.Identity()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): output of transformers, shape [B, L, C] or [B, C]

        Returns:
            torch.Tensor: projected features with the same number of
                dimensions as the input, i.e. [B, L, out_dim] or [B, out_dim]
        """
        assert x.ndim in [2, 3], x.ndim
        add_dim = False
        if x.ndim == 2:
            # [B, C] -> [B, 1, C]
            x = x.unsqueeze(1)
            add_dim = True

        x = rearrange(x, 'b l c -> b c l')
        x = self.linear_hidden(x)
        x = self.linear_out(x)
        x = rearrange(x, 'b c l -> b l c')

        if add_dim:
            x = x.squeeze(1)

        return x
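
# Illustrative shape check for ProjectMLP (the sizes below are assumptions,
# not taken from any config):
#   proj = ProjectMLP(in_dim=384, inner_dim=4096, out_dim=256, num_layers=2)
#   tokens = torch.randn(8, 77, 384)            # [B, L, C] token features
#   assert proj(tokens).shape == (8, 77, 256)   # [B, L, out_dim]
#   pooled = torch.randn(8, 384)                # 2-D input [B, C]
#   assert proj(pooled).shape == (8, 256)       # squeezed back to [B, out_dim]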


@MODELS.register_module()
class MultiLabelContrastive(nn.Module):

    def __init__(self,
                 img_encoder,
                 text_encoder,
                 output_dim=256,
                 contrast_temperature=0.07,
                 proj_num_layers=2,
                 multi_label=0,
                 share_temperature=False,
                 multi_label_loss_weight=1.0):
        super().__init__()

        self.img_encoder = MODELS.build(img_encoder)
        self.text_encoder = MODELS.build(text_encoder)

        self.contrast_temperature = contrast_temperature
        # learnable temperature, initialized to log(1 / temperature) as in CLIP
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
        self.cross_entropy = nn.CrossEntropyLoss()
        self.soft_cross_entropy = SoftTargetCrossEntropy()

        self.proj_num_layers = proj_num_layers
        self.multi_label = multi_label
        if proj_num_layers > 0:
            self.img_projector = ProjectMLP(
                in_dim=self.img_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
            self.text_projector = ProjectMLP(
                in_dim=self.text_encoder.width, num_layers=proj_num_layers, out_dim=output_dim)
            self.img_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.img_projector)
            self.text_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.text_projector)
        else:
            self.img_projector = nn.Identity()
            self.text_projector = nn.Identity()

        self.share_temperature = share_temperature
        if self.with_multi_label and not self.share_temperature:
            self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
        self.multi_label_loss_weight = multi_label_loss_weight

    @property
    def with_multi_label(self):
        return self.multi_label > 0
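
    # Illustrative instantiation sketch (the registry keys and sizes below
    # are assumptions, not taken from a real config):
    #   model = MultiLabelContrastive(
    #       img_encoder=dict(type='GroupViT', embed_dim=384),
    #       text_encoder=dict(type='TextTransformer', width=256),
    #       output_dim=256, multi_label=3)
    # Note: SyncBatchNorm conversion assumes torch.distributed is initialized.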

    def loss(self, image_x, text_x):
        batch_size = image_x.shape[0]
        # get labels globally
        labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + batch_size * dist.get_rank()

        # [B, C]
        image_x = F.normalize(image_x, dim=-1)
        text_x = F.normalize(text_x, dim=-1)

        logits_per_img = image_x @ dist_collect(text_x).t()
        logits_per_text = text_x @ dist_collect(image_x).t()

        logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
        loss_img = self.cross_entropy(logits_per_img * logit_scale, labels)
        loss_text = self.cross_entropy(logits_per_text * logit_scale, labels)

        loss = 0.5 * (loss_img + loss_text)

        return loss
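
    # Worked example of the global label scheme (numbers are assumptions):
    # with batch_size=4 on rank 1 of a 2-GPU run, `labels` is [4, 5, 6, 7],
    # i.e. each local sample's positive is its own row in the gathered
    # [world_size * B, C] feature matrix; every other row is a negative.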

    def multi_label_loss(self, image_feat, text_feat):
        """
        Args:
            image_feat (torch.Tensor): shape [B, L1, C]
            text_feat (torch.Tensor): shape [B, L2, C]

        Returns:
            torch.Tensor: scalar multi-label contrastive loss
        """
        # [B, L1, C], L1 = 1
        image_feat = F.normalize(image_feat, dim=-1)
        # [B, L2, C]
        text_feat = F.normalize(text_feat, dim=-1)

        # [B, L1, L2]
        dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
        # [B, L2, L1]
        dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')

        if self.share_temperature:
            logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
        else:
            logit_scale = torch.clamp(self.multi_label_logit_scale.exp(), max=100)

        batch = image_feat.shape[0]
        img_len = image_feat.shape[1]
        text_len = text_feat.shape[1]
        # [B, L1, L2]
        pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
        # [B, L2, L1]
        pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')

        image_x = rearrange(image_feat, 'b l c -> (b l) c')
        text_x = rearrange(text_feat, 'b l c -> (b l) c')

        logits_per_img = image_x @ dist_collect(text_x).t()
        logits_per_text = text_x @ dist_collect(image_x).t()

        # get labels globally
        # [B, L1, B, L2, W]
        labels_per_img = F.one_hot(
            torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * dist.get_rank(),
            num_classes=dist.get_world_size()).to(image_x.dtype)
        labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
            torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
        # [BxL1, WxBxL2]
        labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
        # [B, L2, B, L1, W]
        labels_per_text = F.one_hot(
            torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * dist.get_rank(),
            num_classes=dist.get_world_size()).to(text_x.dtype)
        labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
            torch.eye(batch, dtype=text_x.dtype, device=text_x.device), 'b2 b1 -> b2 1 b1 1 1')
        # [BxL2, WxBxL1]
        labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')

        loss_img = self.soft_cross_entropy(logits_per_img * logit_scale, labels_per_img)
        loss_text = self.soft_cross_entropy(logits_per_text * logit_scale, labels_per_text)

        loss = 0.5 * (loss_img + loss_text)

        return loss
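
    # How the soft targets above are built: for each image token, a mass of
    # 1/L2 is spread uniformly over the L2 text labels of the same sample;
    # the eye() factor zeroes cross-sample pairs and the one_hot over ranks
    # zeroes cross-rank pairs, so SoftTargetCrossEntropy sees multiple
    # positives per row instead of the single hard label used in `loss`.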

    def encode_image(self, image, *, return_feat=False, as_dict=False):
        outs = Result(as_dict)
        img_outs = self.img_encoder(image, return_feat=return_feat, as_dict=True)
        outs.append(self.img_projector(img_outs['x']), 'image_x')
        if return_feat:
            outs.append(self.img_projector(img_outs['feat']), 'image_feat')
        return outs.as_return()

    def encode_text(self, text, *, as_dict=False):
        assert text.ndim in [2, 3], text.ndim
        squeeze_dim = False
        num_text = 1
        if text.ndim == 3:
            num_text = text.shape[1]
            text = rearrange(text, 'b n l -> (b n) l', n=num_text)
            squeeze_dim = True

        outs = Result(as_dict=as_dict)
        # [B, C]
        x = self.text_encoder(text)
        text_x = self.text_projector(x)
        outs.append(text_x, 'text_x')
        if squeeze_dim:
            text_x = rearrange(text_x, '(b n) c -> b n c', n=num_text)
            text_multi_label_x = text_x[:, 1:]
            text_x = text_x[:, 0]
            outs.update(text_x=text_x, text_multi_label_x=text_multi_label_x)

        return outs.as_return()
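
    # When `text` is 3-D ([B, num_text, L]), each sample carries several
    # tokenized strings: index 0 along dim 1 is returned as the main caption
    # embedding (`text_x`) and the remaining entries as extra label
    # embeddings (`text_multi_label_x`) for the multi-label loss.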

    def forward_train(self, image, text):
        image_outs = self.encode_image(image, as_dict=True)
        # [B, C]
        image_x = image_outs['image_x']

        text_outs = self.encode_text(text, as_dict=True)
        # [B, C]
        text_x = text_outs['text_x']

        losses = self.loss(image_x, text_x)

        losses_dict = dict(loss=losses)
        if self.with_multi_label:
            image_multi_label_x = image_x.unsqueeze(1)
            text_multi_label_x = text_outs['text_multi_label_x']
            losses_dict['multi_label_loss'] = self.multi_label_loss(image_multi_label_x,
                                                                    text_multi_label_x) * self.multi_label_loss_weight

        return losses_dict

    def forward_test(self, image, text):
        return self.zero_shot_pred(image, text)

    def forward(self, image, text):
        if self.training:
            return self.forward_train(image, text)
        else:
            return self.forward_test(image, text)

    @torch.no_grad()
    def build_text_embedding(self, text):
        """
        Args:
            text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH]

        Returns:
            torch.Tensor: L2-normalized class embeddings, shape [NUM_CLASSES, C]
        """
        text = text.to(next(self.parameters()).device)
        num_classes, num_templates = text.shape[:2]
        text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
        text_tokens = self.encode_text(text)
        # [N, T, C]
        text_tokens = rearrange(text_tokens, '(n t) c -> n t c', n=num_classes, t=num_templates)
        # [N, C]
        text_tokens = text_tokens.mean(dim=1)
        text_tokens = F.normalize(text_tokens, dim=-1)

        return text_tokens
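
    # Prompt ensembling as in CLIP: the embeddings of all templates for a
    # class are averaged, and the mean is re-normalized to unit length.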

    @torch.no_grad()
    def zero_shot_pred(self, image, text):
        # [B, C]
        image_features = self.encode_image(image)
        image_features = F.normalize(image_features, dim=-1)

        # cosine similarity as logits
        logits_per_image = image_features @ text.t()

        return logits_per_image
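

# Illustrative zero-shot inference sketch (`tokenize`, `templates` and
# `classnames` are assumptions; this module does not provide a tokenizer):
#   tokens = torch.stack([
#       torch.stack([tokenize(t.format(c)) for t in templates])
#       for c in classnames
#   ])                                               # [N_CLS, N_TPL, CTX_LEN]
#   class_emb = model.build_text_embedding(tokens)   # [N_CLS, C]
#   logits = model.zero_shot_pred(images, class_emb) # [B, N_CLS]
#   pred = logits.argmax(dim=-1)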