group_vit.py
# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from .builder import MODELS
from .misc import Result, interpolate_pos_encoding


class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class MixerMlp(Mlp):

    def forward(self, x):
        return super().forward(x.transpose(1, 2)).transpose(1, 2)


def hard_softmax(logits, dim):
    y_soft = logits.softmax(dim)
    # Straight through.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
    ret = y_hard - y_soft.detach() + y_soft

    return ret


def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
    # _gumbels = (-torch.empty_like(
    #     logits,
    #     memory_format=torch.legacy_contiguous_format).exponential_().log()
    # )  # ~Gumbel(0,1)
    # more stable https://github.com/pytorch/pytorch/issues/41663
    gumbel_dist = torch.distributions.gumbel.Gumbel(
        torch.tensor(0., device=logits.device, dtype=logits.dtype),
        torch.tensor(1., device=logits.device, dtype=logits.dtype))
    gumbels = gumbel_dist.sample(logits.shape)

    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret
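

# Illustrative usage of the straight-through estimators above (a minimal
# sketch, not part of the original module): the forward pass returns an exact
# one-hot assignment, while gradients flow through the soft softmax because
# `y_hard - y_soft.detach() + y_soft` has the gradient of `y_soft`.
#
#   >>> logits = torch.randn(2, 5, requires_grad=True)
#   >>> one_hot = hard_softmax(logits, dim=-1)               # one-hot per row
#   >>> one_hot.sum().backward()                             # gradients still reach `logits`
#   >>> sampled = gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)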


class AssignAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=1,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 hard=True,
                 gumbel=False,
                 gumbel_tau=1.,
                 sum_assign=False,
                 assign_eps=1.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.hard = hard
        self.gumbel = gumbel
        self.gumbel_tau = gumbel_tau
        self.sum_assign = sum_assign
        self.assign_eps = assign_eps

    def get_attn(self, attn, gumbel=None, hard=None):
        if gumbel is None:
            gumbel = self.gumbel
        if hard is None:
            hard = self.hard

        attn_dim = -2
        if gumbel and self.training:
            attn = gumbel_softmax(attn, dim=attn_dim, hard=hard, tau=self.gumbel_tau)
        else:
            if hard:
                attn = hard_softmax(attn, dim=attn_dim)
            else:
                attn = F.softmax(attn, dim=attn_dim)

        return attn

    def forward(self, query, key=None, *, value=None, return_attn=False):
        B, N, C = query.shape
        if key is None:
            key = query
        if value is None:
            value = key
        S = key.size(1)
        # [B, nh, N, C//nh]
        q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
        # [B, nh, S, C//nh]
        k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
        # [B, nh, S, C//nh]
        v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)

        # [B, nh, N, S]
        raw_attn = (q @ k.transpose(-2, -1)) * self.scale

        attn = self.get_attn(raw_attn)
        if return_attn:
            hard_attn = attn.clone()
            soft_attn = self.get_attn(raw_attn, gumbel=False, hard=False)
            attn_dict = {'hard': hard_attn, 'soft': soft_attn}
        else:
            attn_dict = None

        if not self.sum_assign:
            attn = attn / (attn.sum(dim=-1, keepdim=True) + self.assign_eps)
        attn = self.attn_drop(attn)
        assert attn.shape == (B, self.num_heads, N, S)

        # [B, nh, N, C//nh] <- [B, nh, N, S] @ [B, nh, S, C//nh]
        out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out, attn_dict

    def extra_repr(self):
        return f'num_heads: {self.num_heads}, \n' \
               f'hard: {self.hard}, \n' \
               f'gumbel: {self.gumbel}, \n' \
               f'sum_assign={self.sum_assign}, \n' \
               f'gumbel_tau: {self.gumbel_tau}, \n' \
               f'assign_eps: {self.assign_eps}'
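

# Shape sketch for AssignAttention (illustrative comment, not in the original
# file): queries are the projected group tokens and keys/values are the image
# tokens; the softmax runs over dim=-2, i.e. over the group axis, so every
# image token is (hard-)assigned to exactly one group.
#
#   >>> assign = AssignAttention(dim=384, num_heads=1, qkv_bias=True, hard=True)
#   >>> groups = torch.randn(2, 8, 384)      # [B, S_2, C] group tokens
#   >>> tokens = torch.randn(2, 196, 384)    # [B, L, C] image tokens
#   >>> new_groups, attn = assign(groups, tokens, return_attn=True)
#   >>> new_groups.shape                     # torch.Size([2, 8, 384])
#   >>> attn['hard'].shape                   # torch.Size([2, 1, 8, 196])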


class GroupingBlock(nn.Module):
    """Grouping Block to group similar segments together.

    Args:
        dim (int): Dimension of the input.
        out_dim (int): Dimension of the output.
        num_heads (int): Number of heads in the grouping attention.
        num_group_token (int): Number of input group tokens.
        num_output_group (int): Number of output groups.
        norm_layer (nn.Module): Normalization layer to use.
        mlp_ratio (tuple[float]): Ratios of mlp hidden dim to embedding dim for the
            token-mixing and channel MLPs. Default: (0.5, 4.0)
        hard (bool): Whether to use hard or soft assignment. Default: True
        gumbel (bool): Whether to use gumbel softmax. Default: True
        sum_assign (bool): Whether to sum assignment or average. Default: False
        assign_eps (float): Epsilon to avoid divide by zero. Default: 1
        gumbel_tau (float): Temperature for gumbel softmax. Default: 1
    """

    def __init__(self,
                 *,
                 dim,
                 out_dim,
                 num_heads,
                 num_group_token,
                 num_output_group,
                 norm_layer,
                 mlp_ratio=(0.5, 4.0),
                 hard=True,
                 gumbel=True,
                 sum_assign=False,
                 assign_eps=1.,
                 gumbel_tau=1.):
        super(GroupingBlock, self).__init__()
        self.dim = dim
        self.hard = hard
        self.gumbel = gumbel
        self.sum_assign = sum_assign
        self.num_output_group = num_output_group
        # norm on group_tokens
        self.norm_tokens = norm_layer(dim)
        tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
        self.mlp_inter = Mlp(num_group_token, tokens_dim, num_output_group)
        self.norm_post_tokens = norm_layer(dim)
        # norm on x
        self.norm_x = norm_layer(dim)
        self.pre_assign_attn = CrossAttnBlock(
            dim=dim, num_heads=num_heads, mlp_ratio=4, qkv_bias=True, norm_layer=norm_layer, post_norm=True)
        self.assign = AssignAttention(
            dim=dim,
            num_heads=1,
            qkv_bias=True,
            hard=hard,
            gumbel=gumbel,
            gumbel_tau=gumbel_tau,
            sum_assign=sum_assign,
            assign_eps=assign_eps)
        self.norm_new_x = norm_layer(dim)
        self.mlp_channels = Mlp(dim, channels_dim, out_dim)
        if out_dim is not None and dim != out_dim:
            self.reduction = nn.Sequential(norm_layer(dim), nn.Linear(dim, out_dim, bias=False))
        else:
            self.reduction = nn.Identity()

    def extra_repr(self):
        return f'hard={self.hard}, \n' \
               f'gumbel={self.gumbel}, \n' \
               f'sum_assign={self.sum_assign}, \n' \
               f'num_output_group={self.num_output_group}, \n '

    def project_group_token(self, group_tokens):
        """
        Args:
            group_tokens (torch.Tensor): group tokens, [B, S_1, C]

        Returns:
            projected_group_tokens (torch.Tensor): [B, S_2, C], where S_2 is the
                new number of group tokens
        """
        # [B, S_2, C] <- [B, S_1, C]
        projected_group_tokens = self.mlp_inter(group_tokens.transpose(1, 2)).transpose(1, 2)
        projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
        return projected_group_tokens

    def forward(self, x, group_tokens, return_attn=False):
        """
        Args:
            x (torch.Tensor): image tokens, [B, L, C]
            group_tokens (torch.Tensor): group tokens, [B, S_1, C]
            return_attn (bool): whether to return attention map

        Returns:
            new_x (torch.Tensor): [B, S_2, C], S_2 is the new number of
                group tokens
        """
        group_tokens = self.norm_tokens(group_tokens)
        x = self.norm_x(x)
        # [B, S_2, C]
        projected_group_tokens = self.project_group_token(group_tokens)
        projected_group_tokens = self.pre_assign_attn(projected_group_tokens, x)
        new_x, attn_dict = self.assign(projected_group_tokens, x, return_attn=return_attn)
        new_x += projected_group_tokens

        new_x = self.reduction(new_x) + self.mlp_channels(self.norm_new_x(new_x))

        return new_x, attn_dict
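

# Illustrative shapes for GroupingBlock (comment-only sketch, not in the
# original file): the 64 group tokens are first projected to 8 output groups
# by `mlp_inter`, refined by cross-attention over the image tokens, and the
# image tokens are then hard-assigned to those 8 groups.
#
#   >>> block = GroupingBlock(dim=384, out_dim=384, num_heads=6,
#   ...                       num_group_token=64, num_output_group=8,
#   ...                       norm_layer=nn.LayerNorm)
#   >>> x = torch.randn(2, 196, 384)   # [B, L, C] image tokens
#   >>> g = torch.randn(2, 64, 384)    # [B, S_1, C] group tokens
#   >>> new_x, _ = block(x, g)
#   >>> new_x.shape                    # torch.Size([2, 8, 384])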


class Attention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 out_dim=None,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 qkv_fuse=False):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv_fuse = qkv_fuse

        if qkv_fuse:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        else:
            self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
            self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
            self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def extra_repr(self):
        return f'num_heads={self.num_heads}, \n' \
               f'qk_scale={self.scale}, \n' \
               f'qkv_fuse={self.qkv_fuse}'

    def forward(self, query, key=None, *, value=None, mask=None):
        if self.qkv_fuse:
            assert key is None
            assert value is None
            x = query
            B, N, C = x.shape
            S = N
            # [3, B, nh, N, C//nh]
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            # [B, nh, N, C//nh]
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        else:
            B, N, C = query.shape
            if key is None:
                key = query
            if value is None:
                value = key
            S = key.size(1)
            # [B, nh, N, C//nh]
            q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
            # [B, nh, S, C//nh]
            k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
            # [B, nh, S, C//nh]
            v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)

        # [B, nh, N, S]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            attn = attn + mask.unsqueeze(dim=1)
            attn = attn.softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        assert attn.shape == (B, self.num_heads, N, S)

        # [B, nh, N, C//nh] -> [B, N, C]
        # out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out


class CrossAttnBlock(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 post_norm=False):
        super().__init__()
        if post_norm:
            self.norm_post = norm_layer(dim)
            self.norm_q = nn.Identity()
            self.norm_k = nn.Identity()
        else:
            self.norm_q = norm_layer(dim)
            self.norm_k = norm_layer(dim)
            self.norm_post = nn.Identity()
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, query, key, *, mask=None):
        x = query
        x = x + self.drop_path(self.attn(self.norm_q(query), self.norm_k(key), mask=mask))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        x = self.norm_post(x)
        return x


class AttnBlock(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            qkv_fuse=True)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, mask=None):
        x = x + self.drop_path(self.attn(self.norm1(x), mask=mask))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class GroupingLayer(nn.Module):
    """A Transformer layer with Grouping Block for one stage.

    Args:
        dim (int): Number of input channels.
        num_input_token (int): Number of input tokens.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        num_group_token (int): Number of group tokens. 0 means no group token is used.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer.
            In GroupViT setting, Grouping Block serves as the downsampling layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        group_projector (nn.Module | None, optional): Projector for the grouping layer. Default: None.
        zero_init_group_token (bool): Whether to initialize the grouping token to 0. Default: False.
    """

    def __init__(self,
                 dim,
                 num_input_token,
                 depth,
                 num_heads,
                 num_group_token,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 group_projector=None,
                 zero_init_group_token=False):
        super().__init__()
        self.dim = dim
        self.input_length = num_input_token
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.num_group_token = num_group_token
        if num_group_token > 0:
            self.group_token = nn.Parameter(torch.zeros(1, num_group_token, dim))
            if not zero_init_group_token:
                trunc_normal_(self.group_token, std=.02)
        else:
            self.group_token = None

        # build blocks
        blocks = []
        for i in range(depth):
            blocks.append(
                AttnBlock(
                    dim=dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i],
                    norm_layer=norm_layer))
        self.blocks = nn.ModuleList(blocks)

        self.downsample = downsample
        self.input_resolution = num_input_token
        self.group_projector = group_projector

    @property
    def with_group_token(self):
        return self.group_token is not None

    def extra_repr(self):
        return f'dim={self.dim}, \n' \
               f'input_resolution={self.input_resolution}, \n' \
               f'depth={self.depth}, \n' \
               f'num_group_token={self.num_group_token}, \n'

    def split_x(self, x):
        if self.with_group_token:
            return x[:, :-self.num_group_token], x[:, -self.num_group_token:]
        else:
            return x, None

    def concat_x(self, x, group_token=None):
        if group_token is None:
            return x
        return torch.cat([x, group_token], dim=1)

    def forward(self, x, prev_group_token=None, return_attn=False):
        """
        Args:
            x (torch.Tensor): image tokens, [B, L, C]
            prev_group_token (torch.Tensor): group tokens, [B, S_1, C]
            return_attn (bool): whether to return attention maps
        """
        if self.with_group_token:
            group_token = self.group_token.expand(x.size(0), -1, -1)
            if self.group_projector is not None:
                group_token = group_token + self.group_projector(prev_group_token)
        else:
            group_token = None

        B, L, C = x.shape
        cat_x = self.concat_x(x, group_token)
        for blk_idx, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                cat_x = checkpoint.checkpoint(blk, cat_x)
            else:
                cat_x = blk(cat_x)

        x, group_token = self.split_x(cat_x)

        attn_dict = None
        if self.downsample is not None:
            x, attn_dict = self.downsample(x, group_token, return_attn=return_attn)

        return x, group_token, attn_dict
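

# Data flow through one GroupingLayer (comment-only sketch, not in the
# original file): image tokens and learnable group tokens are concatenated,
# processed jointly by the self-attention blocks, split again, and the
# optional GroupingBlock downsample then merges the image tokens into the
# smaller set of output groups.
#
#   >>> layer = GroupingLayer(dim=384, num_input_token=196, depth=6,
#   ...                       num_heads=6, num_group_token=64,
#   ...                       drop_path=[0.0] * 6, downsample=None)
#   >>> x, group_token, attn_dict = layer(torch.randn(2, 196, 384))
#   >>> x.shape, group_token.shape   # ([2, 196, 384], [2, 64, 384])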


class PatchEmbed(nn.Module):
    """Image to Patch Embedding."""

    def __init__(self, img_size=224, kernel_size=7, stride=4, padding=2, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        self.img_size = img_size
        self.patches_resolution = (
            int((img_size[1] + 2 * padding[1] - kernel_size[1]) / stride[1] + 1),
            int((img_size[0] + 2 * padding[0] - kernel_size[0]) / stride[0] + 1),
        )

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    @property
    def num_patches(self):
        return self.patches_resolution[1] * self.patches_resolution[0]

    def forward(self, x):
        B, C, H, W = x.shape
        if self.training:
            # FIXME look at relaxing size constraints
            assert H == self.img_size[0] and W == self.img_size[1], \
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        hw_shape = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, hw_shape
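

# Worked example for the patch grid above (illustrative comment, not in the
# original file): the resolution follows the usual conv output formula
# floor((H + 2*pad - kernel) / stride) + 1. With the GroupViT defaults
# (img_size=224, kernel_size=stride=patch_size=16, padding=0):
#
#   >>> embed = PatchEmbed(img_size=224, kernel_size=16, stride=16, padding=0, embed_dim=384)
#   >>> embed.patches_resolution, embed.num_patches   # ((14, 14), 196)
#   >>> tokens, hw = embed(torch.randn(2, 3, 224, 224))
#   >>> tokens.shape, hw                              # (torch.Size([2, 196, 384]), torch.Size([14, 14]))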


@MODELS.register_module()
class GroupViT(nn.Module):
    r""" Group Vision Transformer
        A PyTorch impl of : `GroupViT: Semantic Segmentation Emerges from Text Supervision`  -
          https://arxiv.org/pdf/2202.11094.pdf

    Args:
        img_size (int | tuple[int]): Input image size. Default: 224
        patch_size (int | tuple[int]): Patch size. Default: 16
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 0
        embed_dim (int): Patch embedding dimension. Default: 384
        embed_factors (list[int]): Embedding dim multipliers for each stage.
        depths (list[int]): Depth of each stage.
        num_heads (list[int]): Number of heads for each stage.
        num_group_tokens (list[int]): Number of group tokens for each stage.
        num_output_groups (list[int]): Number of output groups for each stage.
        hard_assignment (bool): Whether to use hard assignment or not. Default: True
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        pos_embed_type (str): Type of positional embedding. Default: 'simple'
        freeze_patch_embed (bool): Whether to freeze patch embedding. Default: False
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=0,
                 embed_dim=384,
                 embed_factors=[1, 1, 1],
                 depths=[6, 3, 3],
                 num_heads=[6, 6, 6],
                 num_group_tokens=[64, 8, 0],
                 num_output_groups=[64, 8],
                 hard_assignment=True,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 patch_norm=True,
                 use_checkpoint=False,
                 pos_embed_type='simple',
                 freeze_patch_embed=False):
        super().__init__()
        assert patch_size in [4, 8, 16]
        self.num_classes = num_classes
        assert len(embed_factors) == len(depths) == len(num_group_tokens)
        assert all(_ == 0 for _ in num_heads) or len(depths) == len(num_heads)
        assert len(depths) - 1 == len(num_output_groups)
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * embed_factors[len(depths) - 1])
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate
        self.num_group_tokens = num_group_tokens
        self.num_output_groups = num_output_groups
        self.pos_embed_type = pos_embed_type
        assert pos_embed_type in ['simple', 'fourier']

        norm_layer = nn.LayerNorm

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        self.avgpool = nn.AdaptiveAvgPool1d(1)

        if pos_embed_type == 'simple':
            self.pos_embed = self.build_simple_position_embedding()
        elif pos_embed_type == 'fourier':
            self.pos_embed = self.build_2d_sincos_position_embedding()
        else:
            raise ValueError

        if freeze_patch_embed:
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            self.pos_embed.requires_grad = False

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        num_input_token = num_patches
        num_output_token = num_input_token

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):

            dim = int(embed_dim * embed_factors[i_layer])
            downsample = None
            if i_layer < self.num_layers - 1:
                out_dim = embed_dim * embed_factors[i_layer + 1]
                downsample = GroupingBlock(
                    dim=dim,
                    out_dim=out_dim,
                    num_heads=num_heads[i_layer],
                    num_group_token=num_group_tokens[i_layer],
                    num_output_group=num_output_groups[i_layer],
                    norm_layer=norm_layer,
                    hard=hard_assignment,
                    gumbel=hard_assignment)
                num_output_token = num_output_groups[i_layer]

            if i_layer > 0 and num_group_tokens[i_layer] > 0:
                prev_dim = int(embed_dim * embed_factors[i_layer - 1])
                group_projector = nn.Sequential(
                    norm_layer(prev_dim),
                    MixerMlp(num_group_tokens[i_layer - 1], prev_dim // 2, num_group_tokens[i_layer]))

                if dim != prev_dim:
                    group_projector = nn.Sequential(group_projector, norm_layer(prev_dim),
                                                    nn.Linear(prev_dim, dim, bias=False))
            else:
                group_projector = None

            layer = GroupingLayer(
                dim=dim,
                num_input_token=num_input_token,
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                num_group_token=num_group_tokens[i_layer],
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=downsample,
                use_checkpoint=use_checkpoint,
                group_projector=group_projector,
                # only zero init group token if we have a projection
                zero_init_group_token=group_projector is not None)
            self.layers.append(layer)
            if i_layer < self.num_layers - 1:
                num_input_token = num_output_token

        self.norm = norm_layer(self.num_features)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
        if self.pos_embed_type == 'simple' and 'pos_embed' in state_dict:
            load_pos_embed = state_dict['pos_embed']
            pos_embed = self.pos_embed
            if load_pos_embed.shape != pos_embed.shape:
                H_new = int(self.patch_embed.num_patches**0.5)
                W_new = H_new
                H_ori = int(load_pos_embed.shape[1]**0.5)
                W_ori = H_ori
                load_pos_embed = F.interpolate(
                    rearrange(load_pos_embed, 'b (h w) c -> b c h w', h=H_ori, w=W_ori, b=1),
                    size=(H_new, W_new),
                    mode='bicubic',
                    align_corners=False)
                load_pos_embed = rearrange(load_pos_embed, 'b c h w -> b (h w) c', h=H_new, w=W_new)
                state_dict['pos_embed'] = load_pos_embed
        return super().load_state_dict(state_dict, strict)

    def build_simple_position_embedding(self):
        pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches, self.embed_dim))
        trunc_normal_(pos_embed, std=.02)
        return pos_embed

    def build_2d_sincos_position_embedding(self, temperature=10000.):
        h, w = self.patch_embed.patches_resolution
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = self.embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature**omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pos_embed = nn.Parameter(pos_emb)
        pos_embed.requires_grad = False
        return pos_embed

    @property
    def width(self):
        return self.num_features

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_pos_embed(self, B, H, W):
        if self.training:
            return self.pos_embed
        pos_embed = self.pos_embed
        pos_embed = interpolate_pos_encoding(pos_embed, H, W)
        return pos_embed

    def forward_features(self, x, *, return_attn=False):
        B = x.shape[0]
        x, hw_shape = self.patch_embed(x)

        x = x + self.get_pos_embed(B, *hw_shape)
        x = self.pos_drop(x)

        group_token = None
        attn_dict_list = []
        for layer in self.layers:
            x, group_token, attn_dict = layer(x, group_token, return_attn=return_attn)
            attn_dict_list.append(attn_dict)

        x = self.norm(x)

        return x, group_token, attn_dict_list

    def forward_image_head(self, x):
        """
        Args:
            x (torch.Tensor): shape [B, L, C]

        Returns:
            torch.Tensor: classification output of shape [B, num_classes],
                or the pooled [B, C] features if ``num_classes == 0``.
        """
        # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        x = self.head(x)

        return x

    def forward(self, x, *, return_feat=False, return_attn=False, as_dict=False):
        x, group_token, attn_dicts = self.forward_features(x, return_attn=return_attn)
        x_feat = x if return_feat else None

        outs = Result(as_dict=as_dict)

        outs.append(self.forward_image_head(x), name='x')

        if return_feat:
            outs.append(x_feat, name='feat')

        if return_attn:
            outs.append(attn_dicts, name='attn_dicts')

        return outs.as_return()
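

# End-to-end usage sketch (comment only; arguments shown are the class
# defaults, and `Result(as_dict=True)` is assumed to expose a dict-like
# output via `as_return()`): after two grouping stages the 196 patch tokens
# are merged into 64 and then 8 group embeddings.
#
#   >>> model = GroupViT(img_size=224, patch_size=16, embed_dim=384).eval()
#   >>> out = model(torch.randn(2, 3, 224, 224), return_feat=True, as_dict=True)
#   >>> out['feat'].shape   # expected: torch.Size([2, 8, 384]) -- final group embeddings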