123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
def compute_sdm(image_fetures, text_fetures, pid, logit_scale, image_id=None, factor=0.3, epsilon=1e-8):
    """
    Similarity Distribution Matching (SDM) loss.

    Builds a target matching distribution from the identity labels and sums,
    over both retrieval directions, the batch-mean of
    ``sum_j pred_ij * (log pred_ij - log(target_ij + epsilon))`` where ``pred``
    is the softmax of the scaled cosine similarities.

    :param image_fetures: image feature matrix; its first dimension is taken as the batch size
    :param text_fetures: text feature matrix; row i is labelled with the same ``pid[i]`` as image row i
    :param pid: identity labels, reshaped internally to [batch_size, 1]
    :param logit_scale: multiplier applied to the cosine similarities before the softmax
    :param image_id: optional image ids; when given, soft labels are built from ``pid`` and ``image_id``
    :param factor: weight given to pairs that share a pid but not an image id (only used with ``image_id``)
    :param epsilon: added to the target distribution before the log to avoid ``log(0)``
    :return: scalar tensor, image-to-text loss plus text-to-image loss
    """
    batch_size = image_fetures.shape[0]
    pid = pid.reshape((batch_size, 1))  # make sure pid size is [batch_size, 1]
    labels = ((pid - pid.t()) == 0).float()

    if image_id is not None:
        # Mix pid and image id into soft labels: entries whose image ids match
        # and whose pids match stay at 1, entries matching on pid only become
        # `factor`, entries matching on neither stay at 0.
        image_id = image_id.reshape((-1, 1))
        image_id_mask = ((image_id - image_id.t()) == 0).float()
        labels = (labels - image_id_mask) * factor + image_id_mask

    image_norm = image_fetures / image_fetures.norm(dim=1, keepdim=True)
    text_norm = text_fetures / text_fetures.norm(dim=1, keepdim=True)

    t2i_cosine_theta = text_norm @ image_norm.t()
    i2t_cosine_theta = t2i_cosine_theta.t()

    text_proj_image = logit_scale * t2i_cosine_theta
    image_proj_text = logit_scale * i2t_cosine_theta

    # Normalize the true matching distribution row by row. keepdim=True is
    # required: a 1-D row-sum vector would broadcast along columns and divide
    # entry (i, j) by the sum of row j, which leaves soft-label rows
    # un-normalised whenever row sums differ within an identity.
    labels_distribute = labels / labels.sum(dim=1, keepdim=True)
    log_labels = torch.log(labels_distribute + epsilon)

    i2t_pred = F.softmax(image_proj_text, dim=1)
    i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - log_labels)
    t2i_pred = F.softmax(text_proj_image, dim=1)
    t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - log_labels)

    loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1))
    return loss
def compute_mlm(scores, labels):
    """
    Mean cross-entropy between ``scores`` and ``labels``, skipping every
    position whose label equals 0 (``ignore_index=0``).

    The name suggests this is the masked-language-modelling objective, with
    label 0 marking positions that should not contribute -- confirm against
    the caller.
    """
    return F.cross_entropy(scores, labels, ignore_index=0)
def compute_itc(image_features, text_features, logit_scale):
    """
    Image-text contrastive (ITC) loss, InfoNCE style.

    Row i of ``image_features`` is treated as the positive for row i of
    ``text_features``; the loss is the average of the image-to-text and
    text-to-image cross-entropies over scaled cosine-similarity logits.
    """
    n = image_features.shape[0]
    # the matching pair sits on the diagonal, so the target class of row i is i
    targets = torch.arange(n, dtype=torch.int64, device=image_features.device)

    # unit-length features along the last dimension
    image_unit = image_features / image_features.norm(dim=-1, keepdim=True)
    text_unit = text_features / text_features.norm(dim=-1, keepdim=True)

    # scaled cosine similarity as logits (scale is applied to the image side first)
    logits_per_image = (logit_scale * image_unit) @ text_unit.t()

    image_to_text = F.cross_entropy(logits_per_image, targets)
    text_to_image = F.cross_entropy(logits_per_image.t(), targets)
    return (image_to_text + text_to_image) / 2
def compute_id(image_logits, text_logits, labels):
    """
    Instance (identity classification) loss proposed at http://arxiv.org/abs/1711.05535

    Averages the mean cross-entropy of ``image_logits`` and of ``text_logits``
    against the same ``labels``.
    """
    image_loss = F.cross_entropy(image_logits, labels, reduction="mean")
    text_loss = F.cross_entropy(text_logits, labels, reduction="mean")
    return (image_loss + text_loss) / 2
def compute_cmpm(image_embeddings, text_embeddings, labels, epsilon=1e-8):
    """
    Cross-Modal Projection Matching (CMPM) loss.

    :param image_embeddings: Tensor with dtype torch.float32
    :param text_embeddings: Tensor with dtype torch.float32
    :param labels: Tensor with dtype torch.int32; row i labels both image i and text i
    :param epsilon: added to the target before the log to avoid ``log(0)``
    :return: scalar tensor -- batch-mean image-to-text term plus batch-mean
        text-to-image term (no per-direction losses or similarity statistics
        are returned)
    """
    n = image_embeddings.shape[0]
    ids = labels.reshape(n, 1)
    # 1.0 where the two samples carry the same label, 0.0 elsewhere
    match = (ids - ids.t()).eq(0).float()

    unit_image = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
    unit_text = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)

    # raw (un-normalised) embeddings projected onto the other modality's unit vectors
    image_proj_text = image_embeddings @ unit_text.t()
    text_proj_image = text_embeddings @ unit_image.t()

    # Target "distribution": the mask divided by the L2 norm of its rows (not
    # the row sum, so rows are not guaranteed to sum to 1). The 1-D norm vector
    # broadcasts across columns; for an equality mask every non-zero entry
    # (i, j) has identical rows i and j, so this equals a per-row division.
    log_target = torch.log(match / match.norm(dim=1) + epsilon)

    i2t_rows = (F.softmax(image_proj_text, dim=1)
                * (F.log_softmax(image_proj_text, dim=1) - log_target)).sum(dim=1)
    t2i_rows = (F.softmax(text_proj_image, dim=1)
                * (F.log_softmax(text_proj_image, dim=1) - log_target)).sum(dim=1)

    return i2t_rows.mean() + t2i_rows.mean()
|