misc.py

# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
import collections.abc
from collections import OrderedDict

import torch
import torch.distributed as dist

from datasets import template_meta

def reduce_tensor(tensor):
    """Average a tensor over all processes in the distributed group."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
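
# Usage sketch (hypothetical names; assumes torch.distributed is initialized and
# every rank reaches this line, since all_reduce is a collective call):
#
#   loss = criterion(output, target)
#   loss_avg = reduce_tensor(loss.detach())  # mean of the loss across ranks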

def get_grad_norm(parameters, norm_type=2):
    """Return the total `norm_type`-norm of the gradients of `parameters`."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # skip parameters that did not receive a gradient
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item()**norm_type
    total_norm = total_norm**(1. / norm_type)
    return total_norm
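
# Usage sketch (hypothetical `model`; mirrors a typical training step):
#
#   loss.backward()
#   grad_norm = get_grad_norm(model.parameters())  # plain float, handy for logging
#
# Unlike torch.nn.utils.clip_grad_norm_, this only measures the norm and never
# rescales the gradients.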

def get_batch_size(data):
    """Infer the batch size from a tensor, or from tensors nested in a dict/list."""
    if isinstance(data, torch.Tensor):
        return data.size(0)
    elif isinstance(data, collections.abc.Mapping):
        return get_batch_size(data[next(iter(data))])
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
        # elements in the batch are assumed to have a consistent size,
        # so the first one is representative
        it = iter(data)
        return get_batch_size(next(it))
    raise TypeError(f'Unsupported data type {type(data)}')
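
# Usage sketch (hypothetical batch layouts; each call below returns 8):
#
#   get_batch_size(torch.zeros(8, 3, 224, 224))
#   get_batch_size({'image': torch.zeros(8, 3, 224, 224), 'text': torch.zeros(8, 77)})
#   get_batch_size([torch.zeros(8, 3, 224, 224), torch.zeros(8)])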

def data2cuda(data):
    """Recursively move tensors nested in dicts/lists onto the current GPU."""
    if isinstance(data, torch.Tensor):
        return data.cuda(non_blocking=True)
    elif isinstance(data, collections.abc.Mapping):
        return {key: data2cuda(data[key]) for key in data}
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
        return [data2cuda(d) for d in data]
    else:
        raise TypeError(f'Unsupported data type {type(data)}')
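
# Usage sketch (hypothetical `loader`; non_blocking=True only overlaps the copy
# with compute when the source tensors live in pinned memory, e.g. a DataLoader
# built with pin_memory=True):
#
#   for batch in loader:
#       batch = data2cuda(batch)  # same nested structure, tensors now on GPU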

def parse_losses(losses):
    """Sum individual loss terms and collect per-term scalars for logging."""
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(f'{loss_name} is not a tensor or list of tensors')

    # only entries whose key contains 'loss' contribute to the total
    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)

    return loss, log_vars
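
# Usage sketch (hypothetical loss dict; metrics without 'loss' in the key are
# logged but excluded from the total):
#
#   total, log_vars = parse_losses({
#       'contrastive_loss': torch.tensor([0.5, 0.7]),  # averaged to 0.6
#       'acc': torch.tensor(0.9),                      # logged only
#   })
#   # total == tensor(0.6), log_vars keeps both entries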

def build_dataset_class_tokens(text_transform, template_set, classnames):
    """Tokenize every (prompt template, classname) pair for zero-shot evaluation."""
    tokens = []
    templates = template_meta[template_set]
    for classname in classnames:
        # fill each prompt template with the class name
        tokens.append(torch.stack([text_transform(template.format(classname)) for template in templates]))
    # [N, T, L], N: number of classes, T: number of templates, L: sequence length
    tokens = torch.stack(tokens)
    return tokens
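
# Usage sketch (hypothetical; assumes `text_transform` maps a string to an [L]
# token tensor and that 'simple' is a valid key in datasets.template_meta):
#
#   tokens = build_dataset_class_tokens(text_transform, 'simple', ['cat', 'dog'])
#   # tokens.shape == (2, len(template_meta['simple']), L)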