# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# Modified by Jilan Xu
# -------------------------------------------------------------------------
import math

import torch.nn.functional as F


class Result:
    """Collects intermediate outputs either as a named dict or as an ordered tuple."""

    def __init__(self, as_dict=False):
        if as_dict:
            self.outs = {}
        else:
            self.outs = []

    @property
    def as_dict(self):
        return isinstance(self.outs, dict)

    def append(self, element, name=None):
        if self.as_dict:
            assert name is not None, 'a name is required when collecting outputs as a dict'
            self.outs[name] = element
        else:
            self.outs.append(element)

    def update(self, **kwargs):
        if self.as_dict:
            self.outs.update(**kwargs)
        else:
            for v in kwargs.values():
                self.outs.append(v)

    def as_output(self):
        if self.as_dict:
            return self.outs
        return tuple(self.outs)

    def as_return(self):
        outs = self.as_output()
        if self.as_dict:
            return outs
        # Unwrap a single-element tuple so callers get the bare value.
        if len(outs) == 1:
            return outs[0]
        return outs


def interpolate_pos_encoding(pos_embed, H, W):
    """Bicubically resize a square grid of positional embeddings to an H x W grid.

    ``pos_embed`` is assumed to contain patch tokens only. If it still carries
    a class token, ``pos_embed.shape[1]`` counts it, the number of patch
    positions is ``N - 1``, and the square-grid reshape below breaks; strip
    the class token before calling this function.
    """
    num_patches = H * W
    if pos_embed.ndim == 2:
        pos_embed = pos_embed.unsqueeze(0)
    N = pos_embed.shape[1]
    # No resizing needed when the target grid already matches (square input assumed).
    if num_patches == N and W == H:
        return pos_embed
    dim = pos_embed.shape[-1]
    # Restore the flat (1, N, dim) embedding to its original square grid,
    # interpolate to (H, W), then flatten back to (1, H * W, dim).
    patch_pos_embed = F.interpolate(
        pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
        size=(H, W),
        mode='bicubic',
        align_corners=False)
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return patch_pos_embed
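

if __name__ == '__main__':
    # Usage sketch (illustrative, not part of the original repository): the
    # shapes below are assumptions, picked to mimic a ViT with 14 x 14 patches
    # and embedding dim 768 being run at a larger 16 x 16 patch grid. It
    # exercises Result in both collection modes and the resize utility above.
    import torch

    # Dict mode: elements must be named and come back as a dict.
    res = Result(as_dict=True)
    res.append(torch.zeros(2, 3), name='feat')
    res.update(logits=torch.zeros(2, 10))
    assert set(res.as_return().keys()) == {'feat', 'logits'}

    # Tuple mode: a single collected element is unwrapped by as_return().
    res = Result(as_dict=False)
    res.append(torch.zeros(2, 3))
    assert res.as_return().shape == (2, 3)

    # Resize a 14 x 14 positional grid to 16 x 16 (patch-only embedding assumed).
    pos_embed = torch.randn(1, 14 * 14, 768)
    resized = interpolate_pos_encoding(pos_embed, 16, 16)
    assert resized.shape == (1, 16 * 16, 768)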