objectives.py

import torch
import torch.nn as nn
import torch.nn.functional as F

def compute_sdm(image_features, text_features, pid, logit_scale, image_id=None, factor=0.3, epsilon=1e-8):
    """
    Similarity Distribution Matching (SDM) loss.

    Matches the predicted image-text similarity distribution against the
    true matching distribution derived from person IDs (and, optionally,
    image IDs for soft labels).
    """
    batch_size = image_features.shape[0]
    pid = pid.reshape((batch_size, 1))  # make sure pid size is [batch_size, 1]
    pid_dist = pid - pid.t()
    labels = (pid_dist == 0).float()

    if image_id is not None:
        # Mix PID and image ID to create soft labels: pairs that share a PID
        # but not an image ID are down-weighted by `factor`.
        image_id = image_id.reshape((-1, 1))
        image_id_dist = image_id - image_id.t()
        image_id_mask = (image_id_dist == 0).float()
        labels = (labels - image_id_mask) * factor + image_id_mask

    image_norm = image_features / image_features.norm(dim=1, keepdim=True)
    text_norm = text_features / text_features.norm(dim=1, keepdim=True)

    t2i_cosine_theta = text_norm @ image_norm.t()
    i2t_cosine_theta = t2i_cosine_theta.t()

    text_proj_image = logit_scale * t2i_cosine_theta
    image_proj_text = logit_scale * i2t_cosine_theta

    # normalize the true matching distribution row-wise (keepdim so the
    # broadcast divides each row by its own sum)
    labels_distribute = labels / labels.sum(dim=1, keepdim=True)

    i2t_pred = F.softmax(image_proj_text, dim=1)
    i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_distribute + epsilon))
    t2i_pred = F.softmax(text_proj_image, dim=1)
    t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_distribute + epsilon))

    loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1))

    return loss
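
# A minimal usage sketch for compute_sdm (illustrative only: the shapes, the
# number of identities, and the temperature value below are assumptions, not
# part of the original file).
def _demo_compute_sdm():
    batch_size, embed_dim = 8, 512
    image_features = torch.randn(batch_size, embed_dim)
    text_features = torch.randn(batch_size, embed_dim)
    pid = torch.randint(0, 4, (batch_size,))  # hypothetical person IDs; equal IDs form positive pairs
    logit_scale = torch.tensor(1 / 0.07)  # CLIP-style temperature, assumed here
    return compute_sdm(image_features, text_features, pid, logit_scale)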

def compute_mlm(scores, labels):
    """
    Masked language modeling (MLM) loss; positions with label 0 are ignored.
    """
    ce = nn.CrossEntropyLoss(ignore_index=0)
    return ce(scores, labels)
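
# A minimal usage sketch for compute_mlm (illustrative; the vocabulary size
# and the convention that label 0 marks unmasked/padded tokens are assumptions).
def _demo_compute_mlm():
    num_tokens, vocab_size = 16, 1000
    scores = torch.randn(num_tokens, vocab_size)  # token-level logits
    labels = torch.randint(1, vocab_size, (num_tokens,))
    labels[::2] = 0  # ignore_index=0: these positions contribute no loss
    return compute_mlm(scores, labels)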

def compute_itc(image_features, text_features, logit_scale):
    """
    Image-text contrastive (ITC) loss, i.e. InfoNCE.
    """
    batch_size = image_features.shape[0]
    labels = torch.arange(start=0, end=batch_size, dtype=torch.int64)
    labels = labels.to(image_features.device)

    # normalized features
    image_norm = image_features / image_features.norm(dim=-1, keepdim=True)
    text_norm = text_features / text_features.norm(dim=-1, keepdim=True)

    # cosine similarity as logits
    logits_per_image = logit_scale * image_norm @ text_norm.t()
    logits_per_text = logits_per_image.t()

    loss_i = F.cross_entropy(logits_per_image, labels)
    loss_t = F.cross_entropy(logits_per_text, labels)
    loss = (loss_i + loss_t) / 2

    return loss
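
# A minimal usage sketch for compute_itc (illustrative). A CLIP-style learnable
# temperature is assumed; in practice logit_scale would come from the model.
def _demo_compute_itc():
    batch_size, embed_dim = 8, 512
    image_features = torch.randn(batch_size, embed_dim)
    text_features = torch.randn(batch_size, embed_dim)
    # CLIP initializes a learnable log-temperature at log(1/0.07) and
    # exponentiates it before use.
    logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07))).exp()
    return compute_itc(image_features, text_features, logit_scale)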

def compute_id(image_logits, text_logits, labels):
    """
    Instance (ID) loss proposed in http://arxiv.org/abs/1711.05535
    """
    criterion = nn.CrossEntropyLoss(reduction="mean")
    loss = criterion(image_logits, labels) + criterion(text_logits, labels)
    return loss / 2
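
# A minimal usage sketch for compute_id (illustrative; a classifier producing
# per-identity logits for both modalities is an assumption).
def _demo_compute_id():
    batch_size, num_classes = 8, 100
    image_logits = torch.randn(batch_size, num_classes)
    text_logits = torch.randn(batch_size, num_classes)
    labels = torch.randint(0, num_classes, (batch_size,))
    return compute_id(image_logits, text_logits, labels)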

def compute_cmpm(image_embeddings, text_embeddings, labels, epsilon=1e-8):
    """
    Cross-Modal Projection Matching (CMPM) loss.

    :param image_embeddings: Tensor with dtype torch.float32
    :param text_embeddings: Tensor with dtype torch.float32
    :param labels: Tensor with dtype torch.int32
    :return: cmpm_loss, the sum of the image-to-text and text-to-image
        projection matching losses
    """
    batch_size = image_embeddings.shape[0]
    labels_reshape = torch.reshape(labels, (batch_size, 1))
    labels_dist = labels_reshape - labels_reshape.t()
    labels_mask = (labels_dist == 0).float()

    image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
    text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
    # project each embedding onto the unit vectors of the other modality
    image_proj_text = torch.matmul(image_embeddings, text_norm.t())
    text_proj_image = torch.matmul(text_embeddings, image_norm.t())

    # normalize the true matching distribution row-wise so each row sums to 1
    labels_mask_norm = labels_mask / labels_mask.sum(dim=1, keepdim=True)

    i2t_pred = F.softmax(image_proj_text, dim=1)
    i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + epsilon))
    t2i_pred = F.softmax(text_proj_image, dim=1)
    t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + epsilon))

    cmpm_loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1))

    return cmpm_loss
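
# A minimal usage sketch for compute_cmpm (illustrative; the shapes and the
# integer identity labels are assumptions).
def _demo_compute_cmpm():
    batch_size, embed_dim = 8, 256
    image_embeddings = torch.randn(batch_size, embed_dim)
    text_embeddings = torch.randn(batch_size, embed_dim)
    labels = torch.randint(0, 4, (batch_size,))  # hypothetical identity labels
    return compute_cmpm(image_embeddings, text_embeddings, labels)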