# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
# Modified by Jilan Xu
# -------------------------------------------------------------------------
import collections.abc
from collections import OrderedDict
import json
import cv2
import numpy as np
from PIL import Image
from ipdb import set_trace
import scipy.spatial.distance  # scipy.spatial.distance.cdist is used below; plain "import scipy" does not load it
import torch
import torch.distributed as dist

from datasets import template_meta


def reduce_tensor(tensor):
    """All-reduce a tensor across distributed ranks and return its mean."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
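
# Usage sketch (illustrative, not part of this module's API): average a
# per-rank scalar for logging under DistributedDataParallel. Assumes the
# default process group has already been initialized.
#
#   loss = criterion(output, target)            # loss on this rank
#   loss_avg = reduce_tensor(loss.detach())     # mean over all ranks
#   if dist.get_rank() == 0:
#       print(f'averaged loss: {loss_avg.item():.4f}')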


def momentum_update(model_on, model_off, coeff):
    """EMA-style update: blend the offline model's weights towards the online model."""
    for params_on, params_off in zip(model_on.parameters(), model_off.parameters()):
        params_off.data = coeff * params_off.data + (1 - coeff) * params_on.data
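
# Usage sketch (hypothetical names): maintain a momentum "teacher" copy of an
# online "student" network, as in common EMA / teacher-student setups.
#
#   import copy
#   teacher = copy.deepcopy(student)                      # frozen momentum copy
#   for images, targets in loader:
#       ...                                               # update `student` with the optimizer
#       momentum_update(student, teacher, coeff=0.999)    # teacher follows student slowly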


def get_grad_norm(parameters, norm_type=2):
    """Compute the total gradient norm of `parameters`, ignoring params without gradients."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm
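
# Usage sketch: log the gradient norm after backward(), before the optimizer
# step. The norm only covers parameters that actually received gradients.
#
#   loss.backward()
#   grad_norm = get_grad_norm(model.parameters(), norm_type=2)
#   optimizer.step()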


def get_batch_size(data):
    """Recursively infer the batch size of a tensor, dict of tensors, or list of tensors."""
    if isinstance(data, torch.Tensor):
        return data.size(0)
    elif isinstance(data, collections.abc.Mapping):
        # descend into an arbitrary value of the mapping
        return get_batch_size(data[next(iter(data))])
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
        # assume the elements in the batch have consistent size; inspect the first one
        it = iter(data)
        return get_batch_size(next(it))

    raise TypeError(f'Unsupported data type for batch size inference: {type(data)}')


def data2cuda(data):
    """Recursively move tensors (possibly nested in dicts/lists) to the current CUDA device."""
    if isinstance(data, torch.Tensor):
        return data.cuda(non_blocking=True)
    elif isinstance(data, collections.abc.Mapping):
        return {key: data2cuda(data[key]) for key in data}
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
        return [data2cuda(d) for d in data]
    else:
        raise TypeError(f'Unsupported data type for CUDA transfer: {type(data)}')
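
# Usage sketch: move a nested batch from a DataLoader onto the GPU and read
# its batch size. The 'image'/'text' keys are only an example layout.
#
#   batch = {'image': images, 'text': tokens}
#   batch = data2cuda(batch)
#   bs = get_batch_size(batch)      # -> images.size(0)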


def parse_losses(losses):
    """Reduce a dict of losses to a single scalar plus a log dict.

    Each value may be a tensor or a list of tensors; only keys containing
    'loss' contribute to the returned total.
    """
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(f'{loss_name} is not a tensor or list of tensors')

    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
    return loss, log_vars
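
# Usage sketch: a model returns named loss terms; only keys containing 'loss'
# are summed into the scalar used for backprop, while every entry is still
# returned for logging.
#
#   losses = {'contrastive_loss': l_con, 'caption_loss': l_cap, 'acc': acc}
#   total_loss, log_vars = parse_losses(losses)   # 'acc' is logged, not summed
#   total_loss.backward()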


def build_dataset_class_tokens(text_transform, template_set, classnames):
    """Tokenize every (template, classname) prompt combination for zero-shot evaluation."""
    tokens = []
    templates = template_meta[template_set]
    for classname in classnames:
        # fill each prompt template with the class name, then tokenize
        tokens.append(torch.stack([text_transform(template.format(classname)) for template in templates]))
    # [N, T, L], N: number of classes, T: number of templates per class, L: token sequence length
    tokens = torch.stack(tokens)
    return tokens
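
# Usage sketch (assuming a CLIP-style tokenizer as `text_transform` and a
# template set named 'simple' registered in `template_meta`):
#
#   classnames = ['cat', 'dog']
#   class_tokens = build_dataset_class_tokens(text_transform, 'simple', classnames)
#   # class_tokens.shape == (2, num_templates, context_length)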


def build_dataset_class_lists(template_set, classnames):
    """Return the raw prompt strings for every template/class pair, without tokenization."""
    tokens = []
    templates = template_meta[template_set]
    for classname in classnames:
        # fill each prompt template with the class name
        for template in templates:
            tokens.append(template.format(classname))
    # flat list of N * T prompt strings (N classes, T templates per class)
    return tokens
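
# Usage sketch: the string variant is useful when tokenization happens later,
# e.g. inside the text encoder's forward pass.
#
#   prompts = build_dataset_class_lists('simple', ['cat', 'dog'])
#   # e.g. ['a photo of a cat.', ..., 'a photo of a dog.', ...]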


def cdist_(x, metric='euclidean'):
    """Mean pairwise distance within each item of a [B, N, D] batch, averaged over the batch.

    Expects CPU tensors, since the distances are computed with scipy.
    """
    assert len(x.shape) == 3
    if metric != 'JS':
        x_ = torch.split(x, 1, dim=0)  # tuple of [1, N, D] slices
        return np.mean(tuple(map(lambda a: scipy.spatial.distance.cdist(a.squeeze(), a.squeeze(), metric).mean(), x_)))
    else:
        # 'JS' is shorthand for the Jensen-Shannon distance: features are softmax-normalized
        # first, and scipy expects the metric name 'jensenshannon'.
        softmax = torch.nn.Softmax(dim=1)
        x_ = torch.split(softmax(x), 1, dim=0)  # tuple of [1, N, D] slices
        return np.mean(tuple(map(lambda a: scipy.spatial.distance.cdist(a.squeeze(), a.squeeze(), 'jensenshannon').mean(), x_)))
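
# Usage sketch: measure the average intra-sample diversity of a batch of embeddings.
#
#   feats = torch.randn(4, 8, 64)          # [B, N, D], on CPU
#   avg_euclid = cdist_(feats, metric='euclidean')
#   avg_js = cdist_(feats, metric='JS')    # Jensen-Shannon over softmax-normalized features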