misc.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # -------------------------------------------------------------------------
  13. import collections.abc
  14. from collections import OrderedDict
  15. import torch
  16. import torch.distributed as dist
  17. from datasets import template_meta
  18. def reduce_tensor(tensor):
  19. rt = tensor.clone()
  20. dist.all_reduce(rt, op=dist.ReduceOp.SUM)
  21. rt /= dist.get_world_size()
  22. return rt
  23. def get_grad_norm(parameters, norm_type=2):
  24. if isinstance(parameters, torch.Tensor):
  25. parameters = [parameters]
  26. parameters = list(filter(lambda p: p.grad is not None, parameters))
  27. norm_type = float(norm_type)
  28. total_norm = 0
  29. for p in parameters:
  30. param_norm = p.grad.data.norm(norm_type)
  31. total_norm += param_norm.item()**norm_type
  32. total_norm = total_norm**(1. / norm_type)
  33. return total_norm
  34. def get_batch_size(data):
  35. if isinstance(data, torch.Tensor):
  36. return data.size(0)
  37. elif isinstance(data, collections.abc.Mapping):
  38. return get_batch_size(data[next(iter(data))])
  39. elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
  40. # check to make sure that the elements in batch have consistent size
  41. it = iter(data)
  42. return get_batch_size(next(it))
  43. raise TypeError
  44. def data2cuda(data):
  45. if isinstance(data, torch.Tensor):
  46. batch = data.cuda(non_blocking=True)
  47. return batch
  48. elif isinstance(data, collections.abc.Mapping):
  49. return {key: data2cuda(data[key]) for key in data}
  50. elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
  51. return [data2cuda(d) for d in data]
  52. else:
  53. raise TypeError
  54. def parse_losses(losses):
  55. log_vars = OrderedDict()
  56. for loss_name, loss_value in losses.items():
  57. if isinstance(loss_value, torch.Tensor):
  58. log_vars[loss_name] = loss_value.mean()
  59. elif isinstance(loss_value, list):
  60. log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
  61. else:
  62. raise TypeError(f'{loss_name} is not a tensor or list of tensors')
  63. loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
  64. return loss, log_vars
  65. def build_dataset_class_tokens(text_transform, template_set, classnames):
  66. tokens = []
  67. templates = template_meta[template_set]
  68. for classname in classnames:
  69. # format with class
  70. tokens.append(torch.stack([text_transform(template.format(classname)) for template in templates]))
  71. # [N, T, L], N: number of instance, T: number of captions (including ensembled), L: sequence length
  72. tokens = torch.stack(tokens)
  73. return tokens