"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import pickle

import torch
import torch.distributed as dist


def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()
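

# Usage sketch (added for illustration, not part of the original utilities):
# a common pattern is to let only the main process write a file and then call
# synchronize() so that no other rank tries to read it before it exists.
# The helper below is hypothetical; `model` and `path` are placeholder
# arguments standing in for an nn.Module and a checkpoint path.
def _example_save_then_load(model, path):
    if is_main_process():
        torch.save(model.state_dict(), path)
    # every rank waits here until rank 0 has finished writing the checkpoint
    synchronize()
    model.load_state_dict(torch.load(path, map_location="cpu"))
    return model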


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
    size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
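

# Usage sketch (added for illustration, not part of the original utilities):
# during distributed evaluation each rank typically holds a dict mapping
# sample ids to its own predictions; all_gather() collects those dicts so the
# main process can merge them. The helper name and argument below are
# hypothetical.
def _example_accumulate_predictions(predictions_per_rank):
    all_predictions = all_gather(predictions_per_rank)
    if not is_main_process():
        return None
    merged = {}
    for partial in all_predictions:
        merged.update(partial)
    return merged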


def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that the process
    with rank 0 has the averaged results. Returns a dict with the same fields
    as input_dict, after reduction.
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
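

# Usage sketch (added for illustration, not part of the original utilities):
# in a training loop the per-rank loss dict is usually reduced before logging
# so that rank 0 reports losses averaged over all processes. The loss tensors
# are assumed to live on the device the backend operates on (e.g. GPU for
# NCCL). The helper name and argument below are hypothetical.
def _example_log_losses(loss_dict):
    loss_dict_reduced = reduce_dict(loss_dict)
    if is_main_process():
        total_loss = sum(loss for loss in loss_dict_reduced.values())
        print({k: v.item() for k, v in loss_dict_reduced.items()},
              "total:", total_loss.item())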