vision_transformer.py

# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# -------------------------------------------------------------------------
# Modified by Jilan Xu
# -------------------------------------------------------------------------
"""
Mostly copy-paste from DINO and the timm library:
https://github.com/facebookresearch/dino
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
from functools import partial

import torch
import torch.nn as nn
from timm.models.registry import register_model
# NOTE: the `_no_grad_trunc_normal_` helper from DINO is not included in this file;
# timm's equivalent implementation is used here as a drop-in replacement.
from timm.models.layers import trunc_normal_ as _no_grad_trunc_normal_


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        # x: (B, N, C) -> q, k, v: (B, num_heads, N, C // num_heads)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        # return both the projected output and the attention map
        return x, attn


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, init_values=0):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # LayerScale: learnable per-channel scaling of the two residual branches
        if init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        if self.gamma_1 is None:
            x = x + self.drop_path(y)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * y)
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (B, C, H, W) -> (B, embed_dim, H // patch_size, W // patch_size);
        # flattening into a token sequence is done by the caller (see prepare_tokens)
        B, C, H, W = x.shape
        return self.proj(x)


class VisionTransformer(nn.Module):
    """ Vision Transformer """
    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), return_all_tokens=False,
                 init_values=0, use_mean_pooling=False, masked_im_modeling=False):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.return_all_tokens = return_all_tokens

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(depth)])
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

        # masked image modeling
        print('use masked image modeling:', masked_im_modeling)
        self.masked_im_modeling = masked_im_modeling
        if masked_im_modeling:
            self.masked_embed = nn.Parameter(torch.zeros(1, embed_dim))

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x, mask=None):
        B, nc, w, h = x.shape
        # patch linear embedding
        x = self.patch_embed(x)

        # masked image modeling
        if mask is not None:
            x = self.mask_model(x, mask)
        x = x.flatten(2).transpose(1, 2)

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x, return_all_tokens=None, mask=None):
        # masked image modeling requires a patch mask
        if self.masked_im_modeling:
            assert mask is not None
            x = self.prepare_tokens(x, mask=mask)
        else:
            x = self.prepare_tokens(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        if self.fc_norm is not None:
            # mean pooling: replace the [CLS] token with the normalized mean of the patch tokens
            x[:, 0] = self.fc_norm(x[:, 1:, :].mean(1))

        return_all_tokens = self.return_all_tokens if \
            return_all_tokens is None else return_all_tokens
        if return_all_tokens:
            return x
        return x[:, 0]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output

    def get_num_layers(self):
        return len(self.blocks)

    def mask_model(self, x, mask):
        # x: (B, embed_dim, H // patch_size, W // patch_size) patch embeddings;
        # mask: boolean tensor over the patch grid, shape (B, H // patch_size, W // patch_size).
        # Masked positions are replaced in-place with the learnable masked_embed token.
        x.permute(0, 2, 3, 1)[mask, :] = self.masked_embed.to(x.dtype)
        return x


def vit_mini(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=4, num_heads=3, mlp_ratio=4,
        qkv_bias=True, **kwargs)
    return model


def vit_tiny(image_size=[224], patch_size=16, **kwargs):
    model = VisionTransformer(
        img_size=image_size, patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, **kwargs)
    return model


def vit_small(image_size=[224], patch_size=16, **kwargs):
    model = VisionTransformer(
        img_size=image_size, patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, **kwargs)
    return model


def vit_base(image_size=[224], patch_size=16, **kwargs):
    model = VisionTransformer(
        img_size=image_size, patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, **kwargs)
    return model


def vit_large(image_size=[224], patch_size=16, **kwargs):
    model = VisionTransformer(
        img_size=image_size, patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
        qkv_bias=True, **kwargs)
    return model
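

# Minimal usage sketch (illustrative only): builds a small model and runs forward passes on
# random data to show the expected input/output shapes. The batch size, the 224x224 input,
# patch size 16, and the 30% mask ratio below are illustrative assumptions, not requirements.
if __name__ == "__main__":
    model = vit_small(image_size=[224], patch_size=16)
    model.eval()

    images = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        cls_feat = model(images)                             # (2, 384): [CLS] token features
        all_tokens = model(images, return_all_tokens=True)   # (2, 197, 384): [CLS] + 14*14 patch tokens
        attn = model.get_last_selfattention(images)          # (2, 6, 197, 197): last-block attention maps
    print(cls_feat.shape, all_tokens.shape, attn.shape)

    # Masked-image-modeling path: requires masked_im_modeling=True and a boolean mask
    # over the 14x14 patch grid passed to forward().
    mim_model = vit_small(image_size=[224], patch_size=16, masked_im_modeling=True)
    mim_model.eval()
    mask = torch.rand(2, 14, 14) < 0.3  # randomly mask ~30% of the patches
    with torch.no_grad():
        mim_tokens = mim_model(images, return_all_tokens=True, mask=mask)  # (2, 197, 384)
    print(mim_tokens.shape)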