# group_vit.py

# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from .builder import MODELS
from .misc import Result, interpolate_pos_encoding


class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class MixerMlp(Mlp):

    def forward(self, x):
        return super().forward(x.transpose(1, 2)).transpose(1, 2)


def hard_softmax(logits, dim):
    y_soft = logits.softmax(dim)
    # Straight through.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
    ret = y_hard - y_soft.detach() + y_soft

    return ret


def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
    # _gumbels = (-torch.empty_like(
    #     logits,
    #     memory_format=torch.legacy_contiguous_format).exponential_().log()
    # )  # ~Gumbel(0,1)
    # more stable https://github.com/pytorch/pytorch/issues/41663
    gumbel_dist = torch.distributions.gumbel.Gumbel(
        torch.tensor(0., device=logits.device, dtype=logits.dtype),
        torch.tensor(1., device=logits.device, dtype=logits.dtype))
    gumbels = gumbel_dist.sample(logits.shape)

    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft

    return ret
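

# Illustrative sketch (not part of the original file): the straight-through trick
# above returns a one-hot assignment in the forward pass while gradients flow
# through the underlying soft probabilities. The helper name `_example_*` is
# hypothetical and the helper is never called by the model.
def _example_straight_through_gumbel():
    logits = torch.randn(2, 4, requires_grad=True)
    y = gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
    # Forward pass: each row is (numerically) one-hot, so it sums to 1.
    assert torch.allclose(y.sum(dim=-1), torch.ones(2))
    # Backward pass: gradients reach the logits despite the hard argmax.
    target = torch.randn(2, 4)
    (y * target).sum().backward()
    assert logits.grad is not None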


class AssignAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=1,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 hard=True,
                 gumbel=False,
                 gumbel_tau=1.,
                 sum_assign=False,
                 assign_eps=1.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.hard = hard
        self.gumbel = gumbel
        self.gumbel_tau = gumbel_tau
        self.sum_assign = sum_assign
        self.assign_eps = assign_eps

    def get_attn(self, attn, gumbel=None, hard=None):
        if gumbel is None:
            gumbel = self.gumbel
        if hard is None:
            hard = self.hard

        attn_dim = -2
        if gumbel and self.training:
            attn = gumbel_softmax(attn, dim=attn_dim, hard=hard, tau=self.gumbel_tau)
        else:
            if hard:
                attn = hard_softmax(attn, dim=attn_dim)
            else:
                attn = F.softmax(attn, dim=attn_dim)

        return attn

    def forward(self, query, key=None, *, value=None, return_attn=False):
        B, N, C = query.shape
        if key is None:
            key = query
        if value is None:
            value = key
        S = key.size(1)
        # [B, nh, N, C//nh]
        q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
        # [B, nh, S, C//nh]
        k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
        # [B, nh, S, C//nh]
        v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)

        # [B, nh, N, S]
        raw_attn = (q @ k.transpose(-2, -1)) * self.scale

        attn = self.get_attn(raw_attn)
        if return_attn:
            hard_attn = attn.clone()
            soft_attn = self.get_attn(raw_attn, gumbel=False, hard=False)
            attn_dict = {'hard': hard_attn, 'soft': soft_attn}
        else:
            attn_dict = None

        if not self.sum_assign:
            attn = attn / (attn.sum(dim=-1, keepdim=True) + self.assign_eps)
        attn = self.attn_drop(attn)
        assert attn.shape == (B, self.num_heads, N, S)

        # [B, nh, N, C//nh] <- [B, nh, N, S] @ [B, nh, S, C//nh]
        out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads)

        out = self.proj(out)
        out = self.proj_drop(out)
        return out, attn_dict

    def extra_repr(self):
        return f'num_heads: {self.num_heads}, \n' \
               f'hard: {self.hard}, \n' \
               f'gumbel: {self.gumbel}, \n' \
               f'sum_assign={self.sum_assign}, \n' \
               f'gumbel_tau: {self.gumbel_tau}, \n' \
               f'assign_eps: {self.assign_eps}'
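

# Illustrative sketch (not part of the original file): AssignAttention uses the
# group tokens as queries and the image tokens as keys/values, and its softmax
# runs over the group dimension (dim=-2), so with `hard=True` every image token
# is assigned to exactly one group. The helper and its numbers are hypothetical
# and the helper is never called by the model.
def _example_assign_attention_shapes():
    assign = AssignAttention(dim=384, num_heads=1, qkv_bias=True, hard=True, gumbel=False)
    group_tokens = torch.randn(2, 8, 384)    # queries: 8 group tokens
    image_tokens = torch.randn(2, 196, 384)  # keys/values: 196 image tokens
    out, attn_dict = assign(group_tokens, image_tokens, return_attn=True)
    assert out.shape == (2, 8, 384)
    # For each image token, the hard assignment is one-hot over the 8 groups.
    assert attn_dict['hard'].shape == (2, 1, 8, 196)
    assert torch.allclose(attn_dict['hard'].sum(dim=-2), torch.ones(2, 1, 196))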


class GroupingBlock(nn.Module):
    """Grouping Block to group similar segments together.

    Args:
        dim (int): Dimension of the input.
        out_dim (int): Dimension of the output.
        num_heads (int): Number of heads in the grouping attention.
        num_group_token (int): Number of input group tokens.
        num_output_group (int): Number of output groups.
        norm_layer (nn.Module): Normalization layer to use.
        mlp_ratio (tuple[float]): Ratios of the token-mixing and channel MLP
            hidden dims to the embedding dim. Default: (0.5, 4.0)
        hard (bool): Whether to use hard or soft assignment. Default: True
        gumbel (bool): Whether to use gumbel softmax. Default: True
        sum_assign (bool): Whether to sum assignment or average. Default: False
        assign_eps (float): Epsilon to avoid divide by zero. Default: 1
        gumbel_tau (float): Temperature for gumbel softmax. Default: 1
    """

    def __init__(self,
                 *,
                 dim,
                 out_dim,
                 num_heads,
                 num_group_token,
                 num_output_group,
                 norm_layer,
                 mlp_ratio=(0.5, 4.0),
                 hard=True,
                 gumbel=True,
                 sum_assign=False,
                 assign_eps=1.,
                 gumbel_tau=1.):
        super(GroupingBlock, self).__init__()
        self.dim = dim
        self.hard = hard
        self.gumbel = gumbel
        self.sum_assign = sum_assign
        self.num_output_group = num_output_group
        # norm on group_tokens
        self.norm_tokens = norm_layer(dim)
        tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
        self.mlp_inter = Mlp(num_group_token, tokens_dim, num_output_group)
        self.norm_post_tokens = norm_layer(dim)
        # norm on x
        self.norm_x = norm_layer(dim)
        self.pre_assign_attn = CrossAttnBlock(
            dim=dim, num_heads=num_heads, mlp_ratio=4, qkv_bias=True, norm_layer=norm_layer, post_norm=True)

        self.assign = AssignAttention(
            dim=dim,
            num_heads=1,
            qkv_bias=True,
            hard=hard,
            gumbel=gumbel,
            gumbel_tau=gumbel_tau,
            sum_assign=sum_assign,
            assign_eps=assign_eps)
        self.norm_new_x = norm_layer(dim)
        self.mlp_channels = Mlp(dim, channels_dim, out_dim)
        if out_dim is not None and dim != out_dim:
            self.reduction = nn.Sequential(norm_layer(dim), nn.Linear(dim, out_dim, bias=False))
        else:
            self.reduction = nn.Identity()

    def extra_repr(self):
        return f'hard={self.hard}, \n' \
               f'gumbel={self.gumbel}, \n' \
               f'sum_assign={self.sum_assign}, \n' \
               f'num_output_group={self.num_output_group}, \n '

    def project_group_token(self, group_tokens):
        """
        Args:
            group_tokens (torch.Tensor): group tokens, [B, S_1, C]

        Returns:
            projected_group_tokens (torch.Tensor): [B, S_2, C], S_2 is the new
                number of group tokens
        """
        # [B, S_2, C] <- [B, S_1, C]
        projected_group_tokens = self.mlp_inter(group_tokens.transpose(1, 2)).transpose(1, 2)
        projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
        return projected_group_tokens

    def forward(self, x, group_tokens, return_attn=False):
        """
        Args:
            x (torch.Tensor): image tokens, [B, L, C]
            group_tokens (torch.Tensor): group tokens, [B, S_1, C]
            return_attn (bool): whether to return attention map

        Returns:
            new_x (torch.Tensor): [B, S_2, C], S_2 is the new number of
                group tokens
        """
        group_tokens = self.norm_tokens(group_tokens)
        x = self.norm_x(x)
        # [B, S_2, C]
        projected_group_tokens = self.project_group_token(group_tokens)
        projected_group_tokens = self.pre_assign_attn(projected_group_tokens, x)
        new_x, attn_dict = self.assign(projected_group_tokens, x, return_attn=return_attn)
        new_x += projected_group_tokens

        new_x = self.reduction(new_x) + self.mlp_channels(self.norm_new_x(new_x))

        return new_x, attn_dict
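

# Illustrative sketch (not part of the original file): a GroupingBlock reduces a
# set of image tokens to `num_output_group` new tokens by assigning each image
# token to one of the (projected) group tokens. The numbers below are arbitrary
# and the helper is hypothetical; it is never called by the model.
def _example_grouping_block_shapes():
    block = GroupingBlock(
        dim=384,
        out_dim=384,
        num_heads=6,
        num_group_token=64,
        num_output_group=64,
        norm_layer=nn.LayerNorm)
    image_tokens = torch.randn(2, 196, 384)
    group_tokens = torch.randn(2, 64, 384)
    new_x, attn_dict = block(image_tokens, group_tokens, return_attn=True)
    assert new_x.shape == (2, 64, 384)                # 196 image tokens grouped into 64
    assert attn_dict['hard'].shape == (2, 1, 64, 196)  # [B, nh, groups, image tokens]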


class Attention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 out_dim=None,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 qkv_fuse=False):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv_fuse = qkv_fuse

        if qkv_fuse:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        else:
            self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
            self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
            self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def extra_repr(self):
        return f'num_heads={self.num_heads}, \n' \
               f'scale={self.scale}, \n' \
               f'qkv_fuse={self.qkv_fuse}'

    def forward(self, query, key=None, *, value=None, mask=None):
        if self.qkv_fuse:
            assert key is None
            assert value is None
            x = query
            B, N, C = x.shape
            S = N
            # [3, B, nh, N, C//nh]
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            # [B, nh, N, C//nh]
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        else:
            B, N, C = query.shape
            if key is None:
                key = query
            if value is None:
                value = key
            S = key.size(1)
            # [B, nh, N, C//nh]
            q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
            # [B, nh, S, C//nh]
            k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
            # [B, nh, S, C//nh]
            v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)

        # [B, nh, N, S]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            attn = attn + mask.unsqueeze(dim=1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        assert attn.shape == (B, self.num_heads, N, S)

        # [B, nh, N, C//nh] -> [B, N, C]
        # out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out
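

# Illustrative sketch (not part of the original file): Attention doubles as
# self-attention (qkv_fuse=True, single fused projection) and cross-attention
# (separate q/k/v projections, query and key may have different lengths).
# The helper is hypothetical and never called by the model.
def _example_attention_modes():
    x = torch.randn(2, 196, 384)
    ctx = torch.randn(2, 64, 384)
    self_attn = Attention(dim=384, num_heads=6, qkv_fuse=True)
    cross_attn = Attention(dim=384, num_heads=6, qkv_fuse=False)
    assert self_attn(x).shape == (2, 196, 384)
    assert cross_attn(x, ctx).shape == (2, 196, 384)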


class CrossAttnBlock(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 post_norm=False):
        super().__init__()
        if post_norm:
            self.norm_post = norm_layer(dim)
            self.norm_q = nn.Identity()
            self.norm_k = nn.Identity()
        else:
            self.norm_q = norm_layer(dim)
            self.norm_k = norm_layer(dim)
            self.norm_post = nn.Identity()
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, query, key, *, mask=None):
        x = query
        x = x + self.drop_path(self.attn(self.norm_q(query), self.norm_k(key), mask=mask))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        x = self.norm_post(x)
        return x


class AttnBlock(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            qkv_fuse=True)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, mask=None):
        x = x + self.drop_path(self.attn(self.norm1(x), mask=mask))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class GroupingLayer(nn.Module):
    """A Transformer layer with Grouping Block for one stage.

    Args:
        dim (int): Number of input channels.
        num_input_token (int): Number of input tokens (input resolution).
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        num_group_token (int): Number of group tokens for this stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer.
            In GroupViT setting, Grouping Block serves as the downsampling layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        group_projector (nn.Module | None, optional): Projector for the grouping layer. Default: None.
        zero_init_group_token (bool): Whether to initialize the grouping token to 0. Default: False.
    """

    def __init__(self,
                 dim,
                 num_input_token,
                 depth,
                 num_heads,
                 num_group_token,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 group_projector=None,
                 zero_init_group_token=False):
        super().__init__()
        self.dim = dim
        self.input_length = num_input_token
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.num_group_token = num_group_token
        if num_group_token > 0:
            self.group_token = nn.Parameter(torch.zeros(1, num_group_token, dim))
            if not zero_init_group_token:
                trunc_normal_(self.group_token, std=.02)
        else:
            self.group_token = None

        # build blocks
        self.depth = depth
        blocks = []
        for i in range(depth):
            blocks.append(
                AttnBlock(
                    dim=dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i],
                    norm_layer=norm_layer))
        self.blocks = nn.ModuleList(blocks)

        self.downsample = downsample
        self.input_resolution = num_input_token
        self.use_checkpoint = use_checkpoint

        self.group_projector = group_projector

    @property
    def with_group_token(self):
        return self.group_token is not None

    def extra_repr(self):
        return f'dim={self.dim}, \n' \
               f'input_resolution={self.input_resolution}, \n' \
               f'depth={self.depth}, \n' \
               f'num_group_token={self.num_group_token}, \n'

    def split_x(self, x):
        if self.with_group_token:
            return x[:, :-self.num_group_token], x[:, -self.num_group_token:]
        else:
            return x, None

    def concat_x(self, x, group_token=None):
        if group_token is None:
            return x
        return torch.cat([x, group_token], dim=1)

    def forward(self, x, prev_group_token=None, return_attn=False):
        """
        Args:
            x (torch.Tensor): image tokens, [B, L, C]
            prev_group_token (torch.Tensor): group tokens, [B, S_1, C]
            return_attn (bool): whether to return attention maps
        """
        if self.with_group_token:
            group_token = self.group_token.expand(x.size(0), -1, -1)
            if self.group_projector is not None:
                group_token = group_token + self.group_projector(prev_group_token)
        else:
            group_token = None

        B, L, C = x.shape
        cat_x = self.concat_x(x, group_token)
        for blk_idx, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                cat_x = checkpoint.checkpoint(blk, cat_x)
            else:
                cat_x = blk(cat_x)

        x, group_token = self.split_x(cat_x)

        attn_dict = None
        if self.downsample is not None:
            x, attn_dict = self.downsample(x, group_token, return_attn=return_attn)

        return x, group_token, attn_dict
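

# Illustrative sketch (not part of the original file): one GroupViT stage runs
# standard transformer blocks over the concatenation of image and group tokens,
# then (optionally) groups the image tokens with a GroupingBlock downsample.
# All numbers below are arbitrary and the helper is hypothetical; it is never
# called by the model.
def _example_grouping_layer_shapes():
    downsample = GroupingBlock(
        dim=384, out_dim=384, num_heads=6, num_group_token=64, num_output_group=64, norm_layer=nn.LayerNorm)
    layer = GroupingLayer(
        dim=384,
        num_input_token=196,
        depth=2,
        num_heads=6,
        num_group_token=64,
        drop_path=[0.0, 0.0],  # one rate per block; the constructor indexes into it
        downsample=downsample)
    x = torch.randn(2, 196, 384)
    new_x, group_token, attn_dict = layer(x)
    assert new_x.shape == (2, 64, 384)        # grouped tokens passed to the next stage
    assert group_token.shape == (2, 64, 384)  # group tokens after the attention blocks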


class PatchEmbed(nn.Module):
    """Image to Patch Embedding."""

    def __init__(self, img_size=224, kernel_size=7, stride=4, padding=2, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        self.img_size = img_size
        self.patches_resolution = (
            int((img_size[1] + 2 * padding[1] - kernel_size[1]) / stride[1] + 1),
            int((img_size[0] + 2 * padding[0] - kernel_size[0]) / stride[0] + 1),
        )

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    @property
    def num_patches(self):
        return self.patches_resolution[1] * self.patches_resolution[0]

    def forward(self, x):
        B, C, H, W = x.shape
        if self.training:
            # FIXME look at relaxing size constraints
            assert H == self.img_size[0] and W == self.img_size[1], \
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        hw_shape = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, hw_shape
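

# Illustrative sketch (not part of the original file): the patch embedding is a
# strided convolution, so the token grid follows the usual output-size formula
# floor((H + 2p - k) / s) + 1. With the GroupViT-style settings used below
# (224 x 224 input, 16 x 16 non-overlapping patches) this gives a 14 x 14 grid,
# i.e. 196 tokens. The helper is hypothetical and never called by the model.
def _example_patch_embed_shapes():
    patch_embed = PatchEmbed(img_size=224, kernel_size=16, stride=16, padding=0, embed_dim=384)
    assert patch_embed.patches_resolution == (14, 14)
    assert patch_embed.num_patches == 196
    tokens, hw_shape = patch_embed(torch.randn(2, 3, 224, 224))
    assert tokens.shape == (2, 196, 384)
    assert tuple(hw_shape) == (14, 14)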


@MODELS.register_module()
class GroupViT(nn.Module):
    r""" Group Vision Transformer
        A PyTorch impl of: `GroupViT: Semantic Segmentation Emerges from Text Supervision` -
        https://arxiv.org/pdf/2202.11094.pdf

    Args:
        img_size (int | tuple[int]): Input image size. Default: 224
        patch_size (int | tuple[int]): Patch size. Default: 16
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 0
        embed_dim (int): Patch embedding dimension. Default: 384
        embed_factors (list[int]): Embedding dim multipliers for each stage.
        depths (list[int]): Depth of each stage.
        num_heads (list[int]): Number of heads for each stage.
        num_group_tokens (list[int]): Number of group tokens for each stage.
        num_output_groups (list[int]): Number of output groups for each stage.
        hard_assignment (bool): Whether to use hard assignment or not. Default: True
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        pos_embed_type (str): Type of positional embedding. Default: 'simple'
        freeze_patch_embed (bool): Whether to freeze patch embedding. Default: False
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=0,
                 embed_dim=384,
                 embed_factors=[1, 1, 1],
                 depths=[6, 3, 3],
                 num_heads=[6, 6, 6],
                 num_group_tokens=[64, 8, 0],
                 num_output_groups=[64, 8],
                 hard_assignment=True,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 patch_norm=True,
                 use_checkpoint=False,
                 pos_embed_type='simple',
                 freeze_patch_embed=False):
        super().__init__()
        assert patch_size in [4, 8, 16]
        self.num_classes = num_classes
        assert len(embed_factors) == len(depths) == len(num_group_tokens)
        assert all(_ == 0 for _ in num_heads) or len(depths) == len(num_heads)
        assert len(depths) - 1 == len(num_output_groups)
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * embed_factors[len(depths) - 1])
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate
        self.num_group_tokens = num_group_tokens
        self.num_output_groups = num_output_groups
        self.pos_embed_type = pos_embed_type
        assert pos_embed_type in ['simple', 'fourier']

        norm_layer = nn.LayerNorm

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        self.avgpool = nn.AdaptiveAvgPool1d(1)

        if pos_embed_type == 'simple':
            self.pos_embed = self.build_simple_position_embedding()
        elif pos_embed_type == 'fourier':
            self.pos_embed = self.build_2d_sincos_position_embedding()
        else:
            raise ValueError

        if freeze_patch_embed:
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            self.pos_embed.requires_grad = False

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        num_input_token = num_patches
        num_output_token = num_input_token

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):

            dim = int(embed_dim * embed_factors[i_layer])
            downsample = None
            if i_layer < self.num_layers - 1:
                out_dim = embed_dim * embed_factors[i_layer + 1]
                downsample = GroupingBlock(
                    dim=dim,
                    out_dim=out_dim,
                    num_heads=num_heads[i_layer],
                    num_group_token=num_group_tokens[i_layer],
                    num_output_group=num_output_groups[i_layer],
                    norm_layer=norm_layer,
                    hard=hard_assignment,
                    gumbel=hard_assignment)
                num_output_token = num_output_groups[i_layer]

            if i_layer > 0 and num_group_tokens[i_layer] > 0:
                prev_dim = int(embed_dim * embed_factors[i_layer - 1])
                group_projector = nn.Sequential(
                    norm_layer(prev_dim),
                    MixerMlp(num_group_tokens[i_layer - 1], prev_dim // 2, num_group_tokens[i_layer]))

                if dim != prev_dim:
                    group_projector = nn.Sequential(group_projector, norm_layer(prev_dim),
                                                    nn.Linear(prev_dim, dim, bias=False))
            else:
                group_projector = None

            layer = GroupingLayer(
                dim=dim,
                num_input_token=num_input_token,
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                num_group_token=num_group_tokens[i_layer],
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=downsample,
                use_checkpoint=use_checkpoint,
                group_projector=group_projector,
                # only zero init group token if we have a projection
                zero_init_group_token=group_projector is not None)
            self.layers.append(layer)
            if i_layer < self.num_layers - 1:
                num_input_token = num_output_token

        self.norm = norm_layer(self.num_features)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
        if self.pos_embed_type == 'simple' and 'pos_embed' in state_dict:
            load_pos_embed = state_dict['pos_embed']
            pos_embed = self.pos_embed
            if load_pos_embed.shape != pos_embed.shape:
                H_new = int(self.patch_embed.num_patches**0.5)
                W_new = H_new
                H_ori = int(load_pos_embed.shape[1]**0.5)
                W_ori = H_ori
                load_pos_embed = F.interpolate(
                    rearrange(load_pos_embed, 'b (h w) c -> b c h w', h=H_ori, w=W_ori, b=1),
                    size=(H_new, W_new),
                    mode='bicubic',
                    align_corners=False)
                load_pos_embed = rearrange(load_pos_embed, 'b c h w -> b (h w) c', h=H_new, w=W_new)
                state_dict['pos_embed'] = load_pos_embed
        return super().load_state_dict(state_dict, strict)

    def build_simple_position_embedding(self):
        pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches, self.embed_dim))
        trunc_normal_(pos_embed, std=.02)
        return pos_embed

    def build_2d_sincos_position_embedding(self, temperature=10000.):
        h, w = self.patch_embed.patches_resolution
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = self.embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature**omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]

        pos_embed = nn.Parameter(pos_emb)
        pos_embed.requires_grad = False
        return pos_embed

    @property
    def width(self):
        return self.num_features

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_pos_embed(self, B, H, W):
        if self.training:
            return self.pos_embed
        pos_embed = self.pos_embed
        pos_embed = interpolate_pos_encoding(pos_embed, H, W)
        return pos_embed

    def forward_features(self, x, *, return_attn=False):
        B = x.shape[0]
        x, hw_shape = self.patch_embed(x)

        x = x + self.get_pos_embed(B, *hw_shape)
        x = self.pos_drop(x)

        group_token = None
        attn_dict_list = []
        for layer in self.layers:
            x, group_token, attn_dict = layer(x, group_token, return_attn=return_attn)
            attn_dict_list.append(attn_dict)

        x = self.norm(x)

        return x, group_token, attn_dict_list

    def forward_image_head(self, x):
        """
        Args:
            x (torch.Tensor): shape [B, L, C]

        Returns:
            x (torch.Tensor): [B, num_classes] classification logits, or the
                [B, C] pooled features when ``num_classes == 0``.
        """
        # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        x = self.head(x)

        return x

    def forward(self, x, *, return_feat=False, return_attn=False, as_dict=False):
        x, group_token, attn_dicts = self.forward_features(x, return_attn=return_attn)
        x_feat = x if return_feat else None

        outs = Result(as_dict=as_dict)

        outs.append(self.forward_image_head(x), name='x')

        if return_feat:
            outs.append(x_feat, name='feat')

        if return_attn:
            outs.append(attn_dicts, name='attn_dicts')

        return outs.as_return()
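

# Illustrative sketch (not part of the original file): the default constructor
# arguments build the three-stage GroupViT defined above (196 patch tokens ->
# 64 groups -> 8 groups). The helper is hypothetical, never called by the model,
# and inspects `forward_features` only, since the structure of the `Result`
# wrapper returned by `forward` is defined in `.misc`.
def _example_groupvit_forward_features():
    model = GroupViT(img_size=224, patch_size=16)  # default 3-stage configuration
    with torch.no_grad():
        feats, group_token, attn_dicts = model.forward_features(
            torch.randn(1, 3, 224, 224), return_attn=True)
    assert feats.shape == (1, 8, 384)                        # tokens of the last (8-group) stage
    assert len(attn_dicts) == 3                              # one entry per stage
    assert attn_dicts[0]['hard'].shape == (1, 1, 64, 196)    # stage 1: 196 tokens -> 64 groups
    assert attn_dicts[1]['hard'].shape == (1, 1, 8, 64)      # stage 2: 64 tokens -> 8 groups
    assert attn_dicts[-1] is None                            # the last stage has no grouping block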