losses.py

# -------------------------------------------------------------------------
# Written by Jilan Xu
# -------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np
from torch import linalg as LA
from scipy.optimize import linear_sum_assignment
# from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
from ipdb import set_trace
import torch.distributed as dist
import diffdist.functional as diff_dist

def dist_collect(x):
    """ Collect a tensor from all GPUs via a differentiable all_gather.
    args:
        x: shape (mini_batch, ...)
    returns:
        shape (mini_batch * num_gpu, ...)
    """
    x = x.contiguous()
    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous()
                for _ in range(dist.get_world_size())]
    out_list = diff_dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0).contiguous()
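
# Hedged usage sketch, not part of the original file: gathering per-GPU
# features into one global batch for a contrastive loss. Requires
# torch.distributed to be initialized (e.g. via dist.init_process_group);
# the _demo_* name and the shapes are illustrative assumptions.
def _demo_dist_collect(local_feats):
    # local_feats: [mini_batch, dim] on this rank's device
    global_feats = dist_collect(local_feats)  # [mini_batch * world_size, dim]
    return global_feats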

class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.
    For efficiency reasons, the targets don't include the no_object class. Because of this, in general,
    there are more predictions than targets. In that case, we do a 1-to-1 matching of the best predictions,
    while the others are left un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_type='L2'):
        """Creates the matcher.
        Params:
            cost_type: metric for the matching cost; 'L2' (pairwise Euclidean
                distance) or 'cosine' (negative cosine similarity)
        """
        super().__init__()
        self.cost_type = cost_type

    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching.
        Params:
            outputs: [b, k, h * w], k normalized predicted masks per image
            targets: [b, k, h * w], k normalized target masks per image
        Returns:
            Two long tensors (inds, inds2), each of shape [b, k], such that
            outputs[i, inds[i]] is matched to targets[i, inds2[i]] for every
            batch element i.
        """
        bs, num_queries = outputs.shape[:2]
        # build the [b, k, k] cost matrix between predictions and targets
        if self.cost_type == 'L2':
            cost_mask = torch.cdist(outputs, targets, p=2)   # [b, k, k]
        elif self.cost_type == 'cosine':
            # <a, b> / (||a|| * ||b||)
            cos_sim = outputs @ targets.transpose(-2, -1)    # [b, k, k]
            dist_a = LA.norm(outputs, dim=-1).unsqueeze(-1)  # [b, k, 1]
            dist_b = LA.norm(targets, dim=-1).unsqueeze(-2)  # [b, 1, k]
            eps = 1e-6
            # negative cosine similarity as the cost matrix
            cost_mask = -1 * (cos_sim / (dist_a + eps) / (dist_b + eps))
        else:
            raise ValueError(f'unsupported cost_type: {self.cost_type}')

        # solve one linear sum assignment problem per batch element
        inds, inds2 = [], []
        for i in range(bs):
            xx, yy = linear_sum_assignment(cost_mask[i].cpu())
            inds.append(torch.as_tensor(xx, dtype=torch.long))
            inds2.append(torch.as_tensor(yy, dtype=torch.long))
        inds = torch.stack(inds).to(outputs.device)
        inds2 = torch.stack(inds2).to(outputs.device)
        return inds, inds2
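
# Hedged usage sketch, not part of the original file: matching k random
# normalized mask embeddings per image, following the [b, k, h * w] shapes
# documented in forward(). The _demo_* name and shapes are illustrative.
def _demo_hungarian_matcher():
    b, k, hw = 2, 8, 196
    matcher = HungarianMatcher(cost_type='cosine')
    preds = F.normalize(torch.randn(b, k, hw), dim=-1)
    gts = F.normalize(torch.randn(b, k, hw), dim=-1)
    inds, inds2 = matcher(preds, gts)  # each [b, k]
    # reorder targets so gts_matched[i, j] pairs with preds[i, inds[i, j]]
    gts_matched = torch.gather(gts, 1, inds2.unsqueeze(-1).expand(-1, -1, hw))
    return gts_matched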

def dice_loss(inputs, targets, num_masks=None, threshold=0.0, topk_mask=None):
    """
    Compute the DICE loss, similar to generalized IoU for masks.
    Args:
        inputs: A float tensor of arbitrary shape. The predictions for each
            example, expected to be already normalized to [0, 1]
            (e.g. with sigmoid) by the caller.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: normalization constant; defaults to inputs.size(1).
        threshold: currently unused.
        topk_mask: optional [bs, k] mask that zeroes out filtered clusters.
    """
    if num_masks is None:
        num_masks = inputs.size(1)
    if topk_mask is not None:
        # [bs, k, nm] * [bs, k, 1], filter the masked clusters
        inputs = inputs * topk_mask.unsqueeze(-1)
        targets = targets * topk_mask.unsqueeze(-1)

    inputs = inputs.flatten(1)
    targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # +1 smoothing keeps the loss finite for empty masks
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks
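
# Hedged usage sketch, not part of the original file: dice_loss on a toy
# batch of k soft masks per image; all shapes and values are illustrative.
def _demo_dice_loss():
    bs, k, nm = 2, 4, 64
    preds = torch.rand(bs, k, nm)                 # already in [0, 1]
    gts = (torch.rand(bs, k, nm) > 0.5).float()   # binary targets
    keep = torch.ones(bs, k)                      # keep every cluster
    return dice_loss(preds, gts, topk_mask=keep)  # scalar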

def get_logits(dense_feat_1, selected_feat_2, logit_scale):
    # exponentiate and clamp the learnable temperature for stability
    logit_scale_dense = torch.clamp(logit_scale.exp(), max=100)

    i, j, k = dense_feat_1.shape
    l, m, k = selected_feat_2.shape
    dense_feat_1 = dense_feat_1.reshape(-1, k)
    selected_feat_2 = selected_feat_2.reshape(-1, k)
    # all-pairs scaled dot products, reshaped back to [i, l, j, m]
    final_logits_1 = logit_scale_dense * dense_feat_1 @ selected_feat_2.t()
    final_logits_1 = final_logits_1.reshape(i, j, l, m).permute(0, 2, 1, 3)
    return final_logits_1
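
# Hedged usage sketch, not part of the original file: logits between every
# dense token in one batch and every selected token in another. The CLIP-style
# initial temperature (1 / 0.07) and all shapes are illustrative assumptions.
def _demo_get_logits():
    logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
    dense = F.normalize(torch.randn(2, 196, 256), dim=-1)     # [i, j, k]
    selected = F.normalize(torch.randn(2, 5, 256), dim=-1)    # [l, m, k]
    return get_logits(dense, selected, logit_scale)           # [i, l, j, m]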

def sim_matrix(a, b, eps=1e-8):
    """
    Cosine similarity matrix between two batches of vectors;
    eps is added for numerical stability.
    """
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
    return sim_mt

class NormSoftmaxLoss(nn.Module):
    """Symmetric contrastive (InfoNCE-style) loss over a similarity matrix
    whose diagonal holds the positive pairs."""

    def __init__(self, temperature=0.05):
        super().__init__()
        self.temperature = temperature

    def forward(self, x):
        # x: [n, n] similarity matrix; row i and column i form a positive pair
        i_logsm = F.log_softmax(x / self.temperature, dim=1)
        j_logsm = F.log_softmax(x.t() / self.temperature, dim=1)

        # sum over positives (the diagonal)
        idiag = torch.diag(i_logsm)
        loss_i = idiag.sum() / len(idiag)

        jdiag = torch.diag(j_logsm)
        loss_j = jdiag.sum() / len(jdiag)

        return - loss_i - loss_j
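
# Hedged usage sketch, not part of the original file: wiring sim_matrix into
# NormSoftmaxLoss for a symmetric contrastive objective between two random
# feature batches. The _demo_* name and shapes are illustrative.
def _demo_norm_softmax_loss():
    img_feats = torch.randn(8, 256)
    txt_feats = torch.randn(8, 256)
    sims = sim_matrix(img_feats, txt_feats)  # [8, 8], diagonal = positives
    return NormSoftmaxLoss(temperature=0.05)(sims)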