misc.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # -------------------------------------------------------------------------
  13. import math
  14. import torch.nn.functional as F
  15. class Result:
  16. def __init__(self, as_dict=False):
  17. if as_dict:
  18. self.outs = {}
  19. else:
  20. self.outs = []
  21. @property
  22. def as_dict(self):
  23. return isinstance(self.outs, dict)
  24. def append(self, element, name=None):
  25. if self.as_dict:
  26. assert name is not None
  27. self.outs[name] = element
  28. else:
  29. self.outs.append(element)
  30. def update(self, **kwargs):
  31. if self.as_dict:
  32. self.outs.update(**kwargs)
  33. else:
  34. for v in kwargs.values():
  35. self.outs.append(v)
  36. def as_output(self):
  37. if self.as_dict:
  38. return self.outs
  39. else:
  40. return tuple(self.outs)
  41. def as_return(self):
  42. outs = self.as_output()
  43. if self.as_dict:
  44. return outs
  45. if len(outs) == 1:
  46. return outs[0]
  47. return outs
  48. def interpolate_pos_encoding(pos_embed, H, W):
  49. num_patches = H * W
  50. N = pos_embed.shape[1]
  51. if num_patches == N and W == H:
  52. return pos_embed
  53. patch_pos_embed = pos_embed
  54. dim = pos_embed.shape[-1]
  55. patch_pos_embed = F.interpolate(
  56. patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
  57. size=(H, W),
  58. mode='bicubic',
  59. align_corners=False)
  60. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  61. return patch_pos_embed