# misc.py
# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# Modified by Jilan Xu
# -------------------------------------------------------------------------
  14. import math
  15. import torch.nn.functional as F
  16. from ipdb import set_trace
  17. class Result:
  18. def __init__(self, as_dict=False):
  19. if as_dict:
  20. self.outs = {}
  21. else:
  22. self.outs = []
  23. @property
  24. def as_dict(self):
  25. return isinstance(self.outs, dict)
  26. def append(self, element, name=None):
  27. if self.as_dict:
  28. assert name is not None
  29. self.outs[name] = element
  30. else:
  31. self.outs.append(element)
  32. def update(self, **kwargs):
  33. if self.as_dict:
  34. self.outs.update(**kwargs)
  35. else:
  36. for v in kwargs.values():
  37. self.outs.append(v)
  38. def as_output(self):
  39. if self.as_dict:
  40. return self.outs
  41. else:
  42. return tuple(self.outs)
  43. def as_return(self):
  44. outs = self.as_output()
  45. if self.as_dict:
  46. return outs
  47. if len(outs) == 1:
  48. return outs[0]
  49. return outs
  50. def interpolate_pos_encoding(pos_embed, H, W):
  51. num_patches = H * W
  52. ##### problems might occur here######,
  53. ##### N = pos.embed.shape[0] and num_patches could be N - 1
  54. if pos_embed.ndim == 2:
  55. pos_embed = pos_embed.unsqueeze(0)
  56. N = pos_embed.shape[1]
  57. # N = pos_embed.shape[1]
  58. if num_patches == N and W == H:
  59. return pos_embed
  60. patch_pos_embed = pos_embed
  61. dim = pos_embed.shape[-1]
  62. patch_pos_embed = F.interpolate(
  63. patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
  64. size=(H, W),
  65. mode='bicubic',
  66. align_corners=False)
  67. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  68. return patch_pos_embed