# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
- import math
- import torch.nn.functional as F
- class Result:
- def __init__(self, as_dict=False):
- if as_dict:
- self.outs = {}
- else:
- self.outs = []
- @property
- def as_dict(self):
- return isinstance(self.outs, dict)
- def append(self, element, name=None):
- if self.as_dict:
- assert name is not None
- self.outs[name] = element
- else:
- self.outs.append(element)
- def update(self, **kwargs):
- if self.as_dict:
- self.outs.update(**kwargs)
- else:
- for v in kwargs.values():
- self.outs.append(v)
- def as_output(self):
- if self.as_dict:
- return self.outs
- else:
- return tuple(self.outs)
- def as_return(self):
- outs = self.as_output()
- if self.as_dict:
- return outs
- if len(outs) == 1:
- return outs[0]
- return outs
- def interpolate_pos_encoding(pos_embed, H, W):
- num_patches = H * W
- N = pos_embed.shape[1]
- if num_patches == N and W == H:
- return pos_embed
- patch_pos_embed = pos_embed
- dim = pos_embed.shape[-1]
- patch_pos_embed = F.interpolate(
- patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
- size=(H, W),
- mode='bicubic',
- align_corners=False)
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
- return patch_pos_embed