# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# Modified by Jilan Xu
# -------------------------------------------------------------------------
import math

import torch.nn.functional as F

class Result:
    """Accumulates intermediate outputs, either keyed by name (dict mode)
    or positionally (tuple mode)."""

    def __init__(self, as_dict=False):
        if as_dict:
            self.outs = {}
        else:
            self.outs = []

    @property
    def as_dict(self):
        return isinstance(self.outs, dict)

    def append(self, element, name=None):
        if self.as_dict:
            assert name is not None, 'a name is required in dict mode'
            self.outs[name] = element
        else:
            self.outs.append(element)

    def update(self, **kwargs):
        if self.as_dict:
            self.outs.update(**kwargs)
        else:
            for v in kwargs.values():
                self.outs.append(v)

    def as_output(self):
        if self.as_dict:
            return self.outs
        else:
            return tuple(self.outs)

    def as_return(self):
        # Unwrap a single positional output; dicts are returned as-is.
        outs = self.as_output()
        if self.as_dict:
            return outs
        if len(outs) == 1:
            return outs[0]
        return outs
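
# Usage sketch (illustrative, not part of the original file): `Result` lets a
# forward pass collect auxiliary outputs and return them in a uniform way.
# `logits` and `attn_map` below are hypothetical tensors.
#
#     result = Result(as_dict=True)
#     result.append(logits, name='logits')
#     result.update(attn=attn_map)
#     outputs = result.as_return()   # -> {'logits': ..., 'attn': ...}
#
#     result = Result(as_dict=False)
#     result.append(logits)
#     logits = result.as_return()    # a single element is returned unwrapped
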
def interpolate_pos_encoding(pos_embed, H, W):
    """Resize a patch position embedding to an H x W patch grid via bicubic
    interpolation. Expects patch tokens only, laid out on a square grid."""
    num_patches = H * W
    if pos_embed.ndim == 2:
        pos_embed = pos_embed.unsqueeze(0)
    # NOTE: this assumes pos_embed holds patch tokens only. If it also carries
    # a class token, the patch count is N - 1 and the class-token embedding
    # must be split off before calling this function.
    N = pos_embed.shape[1]
    if num_patches == N and W == H:
        return pos_embed
    patch_pos_embed = pos_embed
    dim = pos_embed.shape[-1]
    # Reshape the flat (1, N, dim) embedding to a square (sqrt(N), sqrt(N))
    # grid, resize it to (H, W), then flatten back to (1, H * W, dim).
    patch_pos_embed = F.interpolate(
        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
        size=(H, W),
        mode='bicubic',
        align_corners=False)
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return patch_pos_embed
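
# Usage sketch (illustrative): resize a 14 x 14 position embedding, e.g. from
# a ViT trained at 224 x 224 with 16 x 16 patches, to a 16 x 20 patch grid.
# The embedding dimension 768 is an assumption for the example.
#
#     import torch
#     pos_embed = torch.randn(1, 14 * 14, 768)   # (1, N, dim), no class token
#     resized = interpolate_pos_encoding(pos_embed, H=16, W=20)
#     assert resized.shape == (1, 16 * 20, 768)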