misc.py

# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
import math

import torch.nn.functional as F
class Result:
    """Collects intermediate outputs, either by name (dict mode) or positionally (list mode)."""

    def __init__(self, as_dict=False):
        if as_dict:
            self.outs = {}
        else:
            self.outs = []

    @property
    def as_dict(self):
        return isinstance(self.outs, dict)

    def append(self, element, name=None):
        # In dict mode every element must be stored under an explicit name.
        if self.as_dict:
            assert name is not None
            self.outs[name] = element
        else:
            self.outs.append(element)

    def update(self, **kwargs):
        # Merge keyword entries; in list mode only the values are kept, in order.
        if self.as_dict:
            self.outs.update(**kwargs)
        else:
            for v in kwargs.values():
                self.outs.append(v)

    def as_output(self):
        # Return the collected outputs as a dict or a tuple.
        if self.as_dict:
            return self.outs
        else:
            return tuple(self.outs)

    def as_return(self):
        # Like as_output(), but a single positional output is unwrapped from the tuple.
        outs = self.as_output()
        if self.as_dict:
            return outs
        if len(outs) == 1:
            return outs[0]
        return outs
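A minimal usage sketch of Result (not part of misc.py; the entry names are illustrative): in dict mode entries are keyed by name, while in list mode they are collected positionally and a single output is unwrapped by as_return().

res = Result(as_dict=True)
res.append(1, name='feat')   # named entry (a name is required in dict mode)
res.update(attn=2)           # merge keyword entries
res.as_return()              # -> {'feat': 1, 'attn': 2}

res = Result()               # list mode
res.append(1)
res.as_return()              # -> 1 (a single output is unwrapped)
res.append(2)
res.as_return()              # -> (1, 2)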
def interpolate_pos_encoding(pos_embed, H, W):
    """Resize a (1, N, dim) patch position embedding to an H x W patch grid."""
    num_patches = H * W

    N = pos_embed.shape[1]
    if num_patches == N and W == H:
        return pos_embed

    patch_pos_embed = pos_embed

    dim = pos_embed.shape[-1]
    # The stored embedding is assumed to come from a square sqrt(N) x sqrt(N) grid:
    # reshape it to 2D, bicubically resize to (H, W), then flatten back to (1, H*W, dim).
    patch_pos_embed = F.interpolate(
        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
        size=(H, W),
        mode='bicubic',
        align_corners=False)
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return patch_pos_embed
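A minimal usage sketch of interpolate_pos_encoding (not part of misc.py; the shapes are illustrative): a position embedding learned for a 14 x 14 patch grid with embedding dimension 384 is resized for a 20 x 28 grid.

import torch

pos_embed = torch.randn(1, 196, 384)                       # learned for a 14x14 grid, dim 384
resized = interpolate_pos_encoding(pos_embed, H=20, W=28)  # target grid of 20x28 = 560 patches
print(resized.shape)                                       # torch.Size([1, 560, 384])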