
# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
# Modified by Jilan Xu
# -------------------------------------------------------------------------
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from .builder import MODELS
from .misc import Result, interpolate_pos_encoding

from ipdb import set_trace
import clip
import cv2


class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class MixerMlp(Mlp):

    def forward(self, x):
        return super().forward(x.transpose(1, 2)).transpose(1, 2)


def hard_softmax(logits, dim):
    y_soft = logits.softmax(dim)
    # Straight through.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
    ret = y_hard - y_soft.detach() + y_soft
    return ret


def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
    # _gumbels = (-torch.empty_like(
    #     logits,
    #     memory_format=torch.legacy_contiguous_format).exponential_().log()
    # )  # ~Gumbel(0,1)
    # more stable https://github.com/pytorch/pytorch/issues/41663
    gumbel_dist = torch.distributions.gumbel.Gumbel(
        torch.tensor(0., device=logits.device, dtype=logits.dtype),
        torch.tensor(1., device=logits.device, dtype=logits.dtype))
    gumbels = gumbel_dist.sample(logits.shape)

    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret
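
# A minimal sketch (illustrative shapes, not part of the original code) of the
# straight-through behaviour shared by hard_softmax and gumbel_softmax: the forward
# pass yields a one-hot assignment along `dim`, while gradients flow through the
# soft probabilities.
#
#   logits = torch.randn(2, 8, 196, requires_grad=True)    # [B, groups, tokens]
#   y = gumbel_softmax(logits, tau=1.0, hard=True, dim=-2)  # one-hot over the group dim
#   assert torch.allclose(y.sum(dim=-2), torch.ones(2, 196))
#   loss = (y * torch.randn_like(y)).sum()
#   loss.backward()                                         # gradients reach `logits` via y_soft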


class AssignAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=1,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 hard=True,
                 gumbel=False,
                 gumbel_tau=1.,
                 sum_assign=False,
                 assign_eps=1.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.hard = hard
        self.gumbel = gumbel
        self.gumbel_tau = gumbel_tau
        self.sum_assign = sum_assign
        self.assign_eps = assign_eps

    def get_attn(self, attn, gumbel=None, hard=None):
        if gumbel is None:
            gumbel = self.gumbel
        if hard is None:
            hard = self.hard

        attn_dim = -2
        if gumbel and self.training:
            attn = gumbel_softmax(attn, dim=attn_dim, hard=hard, tau=self.gumbel_tau)
        else:
            if hard:
                attn = hard_softmax(attn, dim=attn_dim)
            else:
                attn = F.softmax(attn, dim=attn_dim)

        return attn

    def forward(self, query, key=None, *, value=None, return_attn=False, mask=None):
        B, N, C = query.shape
        if key is None:
            key = query
        if value is None:
            value = key
        S = key.size(1)
        # [B, nh, N, C//nh]
        q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
        # [B, nh, S, C//nh]
        k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
        # [B, nh, S, C//nh]
        v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)

        # [B, nh, N, S]
        raw_attn = (q @ k.transpose(-2, -1)) * self.scale

        attn = self.get_attn(raw_attn)
        if return_attn:
            hard_attn = attn.clone()
            soft_attn = self.get_attn(raw_attn, gumbel=False, hard=False)
            attn_dict = {'hard': hard_attn, 'soft': soft_attn, 'rawk': key, 'rawq': query, 'k': k, 'q': q}
        else:
            attn_dict = None

        if not self.sum_assign:
            attn = attn / (attn.sum(dim=-1, keepdim=True) + self.assign_eps)
        attn = self.attn_drop(attn)
        assert attn.shape == (B, self.num_heads, N, S)

        # [B, nh, N, C//nh] <- [B, nh, N, S] @ [B, nh, S, C//nh]
        out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads)

        out = self.proj(out)
        out = self.proj_drop(out)
        return out, attn_dict

    def extra_repr(self):
        return f'num_heads: {self.num_heads}, \n' \
               f'hard: {self.hard}, \n' \
               f'gumbel: {self.gumbel}, \n' \
               f'sum_assign={self.sum_assign}, \n' \
               f'gumbel_tau: {self.gumbel_tau}, \n' \
               f'assign_eps: {self.assign_eps}'
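
# A minimal usage sketch (shapes are illustrative, not a prescribed configuration).
# AssignAttention softmaxes over the *query* (group) dimension (attn_dim = -2), so each
# image token distributes its assignment across the group tokens; rows are then
# re-normalised over tokens before aggregating token features into group features.
#
#   assign = AssignAttention(dim=384, num_heads=1, qkv_bias=True, hard=True, gumbel=True)
#   groups = torch.randn(2, 8, 384)     # [B, N, C] projected group tokens (queries)
#   tokens = torch.randn(2, 196, 384)   # [B, S, C] image tokens (keys/values)
#   out, attn_dict = assign(groups, tokens, return_attn=True)
#   # out: [2, 8, 384]; attn_dict['hard']: [2, 1, 8, 196]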


class GroupingBlock(nn.Module):
    """Grouping Block to group similar segments together.

    Args:
        dim (int): Dimension of the input.
        out_dim (int): Dimension of the output.
        num_heads (int): Number of heads in the grouping attention.
        num_group_token (int): Number of group tokens fed into the block.
        num_output_group (int): Number of output groups.
        norm_layer (nn.Module): Normalization layer to use.
        mlp_ratio (float | tuple[float]): Ratio of mlp hidden dim to embedding dim. Default: (0.5, 4.0)
        hard (bool): Whether to use hard or soft assignment. Default: True
        gumbel (bool): Whether to use gumbel softmax. Default: True
        sum_assign (bool): Whether to sum assignment or average. Default: False
        assign_eps (float): Epsilon to avoid divide by zero. Default: 1
        gumbel_tau (float): Temperature for gumbel softmax. Default: 1
    """

    def __init__(self,
                 *,
                 dim,
                 out_dim,
                 num_heads,
                 num_group_token,
                 num_output_group,
                 norm_layer,
                 mlp_ratio=(0.5, 4.0),
                 hard=True,
                 gumbel=True,
                 sum_assign=False,
                 assign_eps=1.,
                 gumbel_tau=1.,
                 attn_drop=0.,
                 ):
        super(GroupingBlock, self).__init__()
        self.dim = dim
        self.hard = hard
        self.gumbel = gumbel
        self.sum_assign = sum_assign
        self.num_output_group = num_output_group
        # norm on group_tokens
        self.norm_tokens = norm_layer(dim)
        tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
        self.mlp_inter = Mlp(num_group_token, tokens_dim, num_output_group)
        self.norm_post_tokens = norm_layer(dim)
        # norm on x
        self.norm_x = norm_layer(dim)
        self.pre_assign_attn = CrossAttnBlock(
            dim=dim, num_heads=num_heads, mlp_ratio=4, qkv_bias=True, norm_layer=norm_layer, post_norm=True)

        self.assign = AssignAttention(
            dim=dim,
            num_heads=1,
            qkv_bias=True,
            hard=hard,
            gumbel=gumbel,
            gumbel_tau=gumbel_tau,
            sum_assign=sum_assign,
            assign_eps=assign_eps,
            attn_drop=attn_drop,
        )
        self.norm_new_x = norm_layer(dim)
        self.mlp_channels = Mlp(dim, channels_dim, out_dim)
        if out_dim is not None and dim != out_dim:
            self.reduction = nn.Sequential(norm_layer(dim), nn.Linear(dim, out_dim, bias=False))
        else:
            self.reduction = nn.Identity()

    def extra_repr(self):
        return f'hard={self.hard}, \n' \
               f'gumbel={self.gumbel}, \n' \
               f'sum_assign={self.sum_assign}, \n' \
               f'num_output_group={self.num_output_group}, \n '

    def project_group_token(self, group_tokens):
        """
        Args:
            group_tokens (torch.Tensor): group tokens, [B, S_1, C]

        Returns:
            projected_group_tokens (torch.Tensor): [B, S_2, C], where S_2 is the new
                number of group tokens produced by the token-mixing MLP
        """
        # [B, S_2, C] <- [B, S_1, C]
        projected_group_tokens = self.mlp_inter(group_tokens.transpose(1, 2)).transpose(1, 2)
        projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
        return projected_group_tokens

    def forward(self, x, group_tokens, return_attn=False):
        """
        Args:
            x (torch.Tensor): image tokens, [B, L, C]
            group_tokens (torch.Tensor): group tokens, [B, S_1, C]
            return_attn (bool): whether to return attention map

        Returns:
            new_x (torch.Tensor): [B, S_2, C], S_2 is the new number of
                group tokens
        """
        group_tokens = self.norm_tokens(group_tokens)
        x = self.norm_x(x)
        # [B, S_2, C]
        projected_group_tokens = self.project_group_token(group_tokens)
        projected_group_tokens = self.pre_assign_attn(projected_group_tokens, x)
        new_x, attn_dict = self.assign(projected_group_tokens, x, return_attn=return_attn)
        new_x += projected_group_tokens

        new_x = self.reduction(new_x) + self.mlp_channels(self.norm_new_x(new_x))

        return new_x, attn_dict
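
# A minimal usage sketch with illustrative sizes (not tied to a particular stage of the
# default configuration): 64 incoming group tokens are merged into 8 output groups.
#
#   block = GroupingBlock(dim=384, out_dim=384, num_heads=6, num_group_token=64,
#                         num_output_group=8, norm_layer=nn.LayerNorm)
#   x = torch.randn(2, 196, 384)             # [B, L, C] image tokens
#   group_tokens = torch.randn(2, 64, 384)   # [B, S_1, C]
#   new_x, attn_dict = block(x, group_tokens)
#   # new_x: [2, 8, 384] -- grouped segment tokens passed to the next stage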


class Attention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 out_dim=None,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 qkv_fuse=False):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv_fuse = qkv_fuse

        if qkv_fuse:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        else:
            self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
            self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
            self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def extra_repr(self):
        return f'num_heads={self.num_heads}, \n' \
               f'qk_scale={self.scale}, \n' \
               f'qkv_fuse={self.qkv_fuse}'

    def forward(self, query, key=None, *, value=None, mask=None):
        if self.qkv_fuse:
            assert key is None
            assert value is None
            x = query
            B, N, C = x.shape
            S = N
            # [3, B, nh, N, C//nh]
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            # [B, nh, N, C//nh]
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        else:
            B, N, C = query.shape
            if key is None:
                key = query
            if value is None:
                value = key
            S = key.size(1)
            # [B, nh, N, C//nh]
            q = rearrange(self.q_proj(query), 'b n (h c)-> b h n c', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
            # [B, nh, S, C//nh]
            k = rearrange(self.k_proj(key), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)
            # [B, nh, S, C//nh]
            v = rearrange(self.v_proj(value), 'b n (h c)-> b h n c', h=self.num_heads, b=B, c=C // self.num_heads)

        # [B, nh, N, S]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        assert attn.shape == (B, self.num_heads, N, S)

        # [B, nh, N, C//nh] -> [B, N, C]
        # out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = rearrange(attn @ v, 'b h n c -> b n (h c)', h=self.num_heads, b=B, n=N, c=C // self.num_heads)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out


class CrossAttnBlock(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 post_norm=False):
        super().__init__()
        if post_norm:
            self.norm_post = norm_layer(dim)
            self.norm_q = nn.Identity()
            self.norm_k = nn.Identity()
        else:
            self.norm_q = norm_layer(dim)
            self.norm_k = norm_layer(dim)
            self.norm_post = nn.Identity()
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, query, key, *, mask=None):
        x = query
        x = x + self.drop_path(self.attn(self.norm_q(query), self.norm_k(key), mask=mask))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        x = self.norm_post(x)
        return x


class AttnBlock(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            qkv_fuse=True)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, mask=None):
        x = x + self.drop_path(self.attn(self.norm1(x), mask=mask))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class GroupingLayer(nn.Module):
    """A Transformer layer with Grouping Block for one stage.

    Args:
        dim (int): Number of input channels.
        num_input_token (int): Input resolution (number of input tokens).
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        num_group_token (int): Number of group tokens appended in this stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer.
            In the GroupViT setting, the Grouping Block serves as the downsampling layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        group_projector (nn.Module | None, optional): Projector for the grouping layer. Default: None.
        zero_init_group_token (bool): Whether to initialize the group token to 0. Default: False.
    """

    def __init__(self,
                 dim,
                 num_input_token,
                 depth,
                 num_heads,
                 num_group_token,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 group_projector=None,
                 zero_init_group_token=False,
                 ):
        super().__init__()
        self.dim = dim
        self.input_length = num_input_token
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.num_group_token = num_group_token
        if num_group_token > 0:
            self.group_token = nn.Parameter(torch.zeros(1, num_group_token, dim))
            if not zero_init_group_token:
                trunc_normal_(self.group_token, std=.02)
        else:
            self.group_token = None

        # build blocks
        self.depth = depth
        blocks = []
        for i in range(depth):
            blocks.append(
                AttnBlock(
                    dim=dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i],
                    norm_layer=norm_layer))
        self.blocks = nn.ModuleList(blocks)

        self.downsample = downsample
        self.input_resolution = num_input_token
        self.use_checkpoint = use_checkpoint

        self.group_projector = group_projector

    @property
    def with_group_token(self):
        return self.group_token is not None

    def extra_repr(self):
        return f'dim={self.dim}, \n' \
               f'input_resolution={self.input_resolution}, \n' \
               f'depth={self.depth}, \n' \
               f'num_group_token={self.num_group_token}, \n'

    def split_x(self, x):
        if self.with_group_token:
            return x[:, :-self.num_group_token], x[:, -self.num_group_token:]
        else:
            return x, None

    def concat_x(self, x, group_token=None):
        if group_token is None:
            return x
        return torch.cat([x, group_token], dim=1)

    def forward(self, x, prev_group_token=None, return_attn=False):
        """
        Args:
            x (torch.Tensor): image tokens, [B, L, C]
            prev_group_token (torch.Tensor): group tokens, [B, S_1, C]
            return_attn (bool): whether to return attention maps
        """
        if self.with_group_token:
            group_token = self.group_token.expand(x.size(0), -1, -1)
            if self.group_projector is not None:
                group_token = group_token + self.group_projector(prev_group_token)
        else:
            group_token = None

        B, L, C = x.shape
        cat_x = self.concat_x(x, group_token)
        for blk_idx, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                cat_x = checkpoint.checkpoint(blk, cat_x)
            else:
                cat_x = blk(cat_x)

        x, group_token = self.split_x(cat_x)

        attn_dict = None
        if self.downsample is not None:
            x, attn_dict = self.downsample(x, group_token, return_attn=return_attn)

        return x, group_token, attn_dict
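
# A minimal sketch of one stage (illustrative sizes, not part of the original code):
# the learnable group tokens are concatenated with the image tokens, processed jointly
# by the AttnBlocks, split back out, and optionally merged by the GroupingBlock passed
# in as `downsample`.
#
#   layer = GroupingLayer(dim=384, num_input_token=196, depth=6, num_heads=6,
#                         num_group_token=64, drop_path=[0.0] * 6)
#   x = torch.randn(2, 196, 384)
#   x, group_token, attn_dict = layer(x)
#   # x: [2, 196, 384], group_token: [2, 64, 384], attn_dict is None without a downsample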


class PatchEmbed(nn.Module):
    """Image to Patch Embedding."""

    def __init__(self, img_size=224, kernel_size=7, stride=4, padding=2, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        self.img_size = img_size
        self.patches_resolution = (
            int((img_size[1] + 2 * padding[1] - kernel_size[1]) / stride[1] + 1),
            int((img_size[0] + 2 * padding[0] - kernel_size[0]) / stride[0] + 1),
        )

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    @property
    def num_patches(self):
        return self.patches_resolution[1] * self.patches_resolution[0]

    def forward(self, x):
        B, C, H, W = x.shape
        if self.training:
            # FIXME look at relaxing size constraints
            assert H == self.img_size[0] and W == self.img_size[1], \
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        hw_shape = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, hw_shape
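
# Patch-count arithmetic for the configuration GroupViT uses below (a worked example of
# the conv output-size formula in patches_resolution): with img_size=224, kernel_size=16,
# stride=16, padding=0, each side yields
#   (224 + 2*0 - 16) / 16 + 1 = 14
# so the patch grid is 14 x 14 = 196 tokens.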


@MODELS.register_module()
class GroupViT(nn.Module):
    r""" Group Vision Transformer
        A PyTorch impl of : `GroupViT: Semantic Segmentation Emerges from Text Supervision` -
        https://arxiv.org/pdf/2202.11094.pdf

    Args:
        img_size (int | tuple[int]): Input image size. Default: 224
        patch_size (int | tuple[int]): Patch size. Default: 16
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 0
        embed_dim (int): Patch embedding dimension. Default: 384
        embed_factors (list[int]): Embedding dim multipliers for each stage.
        depths (list[int]): Depth of each stage.
        num_heads (list[int]): Number of heads for each stage.
        num_group_tokens (list[int]): Number of group tokens for each stage.
        num_output_groups (list[int]): Number of output groups for each grouping stage.
        hard_assignment (bool): Whether to use hard assignment or not. Default: True
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        pos_embed_type (str): Type of positional embedding. Default: 'simple'
        freeze_patch_embed (bool): Whether to freeze patch embedding. Default: False
        imgnet_pretrained (str | None): Source of pretrained backbone weights
            ('imgnet', 'dino', 'dinob8', 'dinos16', 'dinos8' or 'clip'). Default: None
        fixed (bool): Whether to freeze the parameters loaded from the pretrained
            checkpoint. Default: False
        imgnet_pretrained_checkpoint (str): Path to the pretrained backbone checkpoint.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=0,
                 embed_dim=384,
                 embed_factors=[1, 1, 1],
                 depths=[6, 3, 3],
                 num_heads=[6, 6, 6],
                 num_group_tokens=[64, 8, 0],
                 num_output_groups=[64, 8],
                 hard_assignment=True,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 patch_norm=True,
                 use_checkpoint=False,
                 pos_embed_type='simple',
                 freeze_patch_embed=False,
                 imgnet_pretrained=None,
                 fixed=False,
                 imgnet_pretrained_checkpoint='/mnt/petrelfs/xujilan/checkpoints/dino_vitbase16_pretrain.pth',
                 ):
        super().__init__()
        assert patch_size in [4, 8, 16]
        self.num_classes = num_classes
        assert len(embed_factors) == len(depths) == len(num_group_tokens)
        assert all(_ == 0 for _ in num_heads) or len(depths) == len(num_heads)
        # assert len(depths) - 1 == len(num_output_groups)
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * embed_factors[len(depths) - 1])
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate
        self.num_group_tokens = num_group_tokens
        self.num_output_groups = num_output_groups
        self.pos_embed_type = pos_embed_type
        assert pos_embed_type in ['simple', 'fourier']
        self.freeze_backbone = fixed

        norm_layer = nn.LayerNorm

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        self.avgpool = nn.AdaptiveAvgPool1d(1)

        if pos_embed_type == 'simple':
            self.pos_embed = self.build_simple_position_embedding()
        elif pos_embed_type == 'fourier':
            self.pos_embed = self.build_2d_sincos_position_embedding()
        else:
            raise ValueError

        if freeze_patch_embed:
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            self.pos_embed.requires_grad = False

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        num_input_token = num_patches
        num_output_token = num_input_token

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):

            dim = int(embed_dim * embed_factors[i_layer])
            downsample = None
            if i_layer < self.num_layers - 1:
                out_dim = embed_dim * embed_factors[i_layer + 1]
                downsample = GroupingBlock(
                    dim=dim,
                    out_dim=out_dim,
                    num_heads=num_heads[i_layer],
                    num_group_token=num_group_tokens[i_layer],
                    num_output_group=num_output_groups[i_layer],
                    norm_layer=norm_layer,
                    hard=hard_assignment,
                    gumbel=hard_assignment,
                    attn_drop=attn_drop_rate,
                )
                num_output_token = num_output_groups[i_layer]

            if i_layer > 0 and num_group_tokens[i_layer] > 0:
                prev_dim = int(embed_dim * embed_factors[i_layer - 1])
                group_projector = nn.Sequential(
                    norm_layer(prev_dim),
                    MixerMlp(num_group_tokens[i_layer - 1], prev_dim // 2, num_group_tokens[i_layer]))

                if dim != prev_dim:
                    group_projector = nn.Sequential(group_projector, norm_layer(prev_dim),
                                                    nn.Linear(prev_dim, dim, bias=False))
            else:
                group_projector = None

            layer = GroupingLayer(
                dim=dim,
                num_input_token=num_input_token,
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                num_group_token=num_group_tokens[i_layer],
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=downsample,
                use_checkpoint=use_checkpoint,
                group_projector=group_projector,
                # only zero init group token if we have a projection
                zero_init_group_token=group_projector is not None,
            )
            self.layers.append(layer)
            if i_layer < self.num_layers - 1:
                num_input_token = num_output_token

        self.norm = norm_layer(self.num_features)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

        self.imgnet_pretrained = imgnet_pretrained
        self.proj = None

        if imgnet_pretrained is not None:
            ### add cls_token to enable params loading ###
            self.pos_embed = self.build_simple_position_embedding_with_cls_token()
            self.init_backbone_with_imagenet_weights(imgnet_pretrained_checkpoint)
            ### drop cls_token ###
            self.pos_embed = nn.Parameter(self.pos_embed[0, 1:])
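
    # Token flow for the default configuration (depths=[6, 3, 3], num_group_tokens=[64, 8, 0],
    # num_output_groups=[64, 8], patch_size=16, img_size=224), traced from the loop above:
    #   stage 0: 196 patch tokens + 64 group tokens -> GroupingBlock -> 64 group features
    #   stage 1:  64 tokens       +  8 group tokens -> GroupingBlock ->  8 group features
    #   stage 2:   8 tokens (no group tokens, no grouping block)
    # so forward_features() returns x of shape [B, 8, embed_dim] before pooling.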

    def init_backbone_with_imagenet_weights(self, checkpoint_path):
        if self.imgnet_pretrained == 'imgnet':
            from timm.models import vit_base_patch16_224
            net = vit_base_patch16_224(pretrained=True)
            state_dict = net.state_dict()
        elif self.imgnet_pretrained in ['dino', 'dinob8', 'dinos16', 'dinos8']:
            state_dict = torch.load(checkpoint_path)
        elif self.imgnet_pretrained == 'clip':
            clip_model, _ = clip.load('ViT-B/16', device='cuda', jit=False)
            state_dict = clip_model.visual.state_dict()

        print('Initializing ImageNet-pretrained weights')
        print('$' * 100)

        newdict = {}
        if self.imgnet_pretrained != 'clip':
            if self.num_layers == 2:
                for kk, vv in state_dict.items():
                    newkey = kk
                    if kk.startswith('blocks.'):
                        layerid = int(kk.split('.')[1])
                        if 0 <= layerid < 6:
                            newkey = 'layers.0.' + kk
                        elif 6 <= layerid < 12:
                            old_prefix = 'blocks.' + str(layerid) + '.'
                            new_prefix = 'blocks.' + str(layerid - 6) + '.'
                            suffix = kk.split(old_prefix)[1]
                            newkey = 'layers.1.' + new_prefix + suffix
                    newdict[newkey] = vv
            elif self.num_layers == 3:
                for kk, vv in state_dict.items():
                    newkey = kk
                    if kk.startswith('blocks.'):
                        layerid = int(kk.split('.')[1])
                        if 0 <= layerid < 6:
                            newkey = 'layers.0.' + kk
                        elif 6 <= layerid < 9:
                            old_prefix = 'blocks.' + str(layerid) + '.'
                            new_prefix = 'blocks.' + str(layerid - 6) + '.'
                            suffix = kk.split(old_prefix)[1]
                            newkey = 'layers.1.' + new_prefix + suffix
                        elif 9 <= layerid < 12:
                            old_prefix = 'blocks.' + str(layerid) + '.'
                            new_prefix = 'blocks.' + str(layerid - 9) + '.'
                            suffix = kk.split(old_prefix)[1]
                            newkey = 'layers.2.' + new_prefix + suffix
                    newdict[newkey] = vv
        else:
            for kk, vv in state_dict.items():
                newkey = kk
                newkey = newkey.replace('transformer.', '')
                newkey = newkey.replace('resblocks', 'blocks')
                newkey = newkey.replace('attn.in_proj_weight', 'attn.qkv.weight')
                newkey = newkey.replace('attn.in_proj_bias', 'attn.qkv.bias')
                newkey = newkey.replace('attn.out_proj.weight', 'attn.proj.weight')
                newkey = newkey.replace('attn.out_proj.bias', 'attn.proj.bias')
                newkey = newkey.replace('ln_1.weight', 'norm1.weight')
                newkey = newkey.replace('ln_1.bias', 'norm1.bias')
                newkey = newkey.replace('ln_2.weight', 'norm2.weight')
                newkey = newkey.replace('ln_2.bias', 'norm2.bias')
                newkey = newkey.replace('mlp.c_fc.weight', 'mlp.fc1.weight')
                newkey = newkey.replace('mlp.c_fc.bias', 'mlp.fc1.bias')
                newkey = newkey.replace('mlp.c_proj.weight', 'mlp.fc2.weight')
                newkey = newkey.replace('mlp.c_proj.bias', 'mlp.fc2.bias')
                newkey = newkey.replace('ln_post.weight', 'norm.weight')
                newkey = newkey.replace('ln_post.bias', 'norm.bias')
                newkey = newkey.replace('positional_embedding', 'pos_embed')
                newkey = newkey.replace('conv1.weight', 'patch_embed.proj.weight')
                kk = newkey

                if newkey == 'proj':
                    self.proj = nn.Parameter(torch.zeros(vv.shape[0], vv.shape[1]))
                if newkey == 'pos_embed':
                    vv = vv.unsqueeze(0)

                if kk.startswith('blocks.'):
                    layerid = int(kk.split('.')[1])
                    if 0 <= layerid < 6:
                        newkey = 'layers.0.' + kk
                    elif 6 <= layerid < 12:
                        old_prefix = 'blocks.' + str(layerid) + '.'
                        new_prefix = 'blocks.' + str(layerid - 6) + '.'
                        suffix = kk.split(old_prefix)[1]
                        newkey = 'layers.1.' + new_prefix + suffix
                newdict[newkey] = vv

        ### init all self-attn/pos_embed/patch_embed layers ###
        msg = self.load_state_dict(newdict, strict=False)

        if self.freeze_backbone:
            for n, p in self.named_parameters():
                if n in newdict:
                    p.requires_grad = False
                    print('Freezing parameter: ', n)

        print(msg)
        print('$' * 100)
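
    # Worked examples of the key remapping above for a 2-stage model (illustrative keys):
    #   'blocks.3.attn.qkv.weight' -> 'layers.0.blocks.3.attn.qkv.weight'
    #   'blocks.7.attn.qkv.weight' -> 'layers.1.blocks.1.attn.qkv.weight'
    # and for the CLIP branch, e.g.
    #   'transformer.resblocks.7.attn.in_proj_weight' -> 'layers.1.blocks.1.attn.qkv.weight'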

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
        if self.pos_embed_type == 'simple' and 'pos_embed' in state_dict:
            load_pos_embed = state_dict['pos_embed']
            pos_embed = self.pos_embed
            if load_pos_embed.shape != pos_embed.shape:
                H_new = int(self.patch_embed.num_patches**0.5)
                W_new = H_new
                H_ori = int(load_pos_embed.shape[1]**0.5)
                W_ori = H_ori
                load_pos_embed = F.interpolate(
                    rearrange(load_pos_embed, 'b (h w) c -> b c h w', h=H_ori, w=W_ori, b=1),
                    size=(H_new, W_new),
                    mode='bicubic',
                    align_corners=False)
                load_pos_embed = rearrange(load_pos_embed, 'b c h w -> b (h w) c', h=H_new, w=W_new)
                state_dict['pos_embed'] = load_pos_embed
        return super().load_state_dict(state_dict, strict)
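
    # Worked example of the resizing above (illustrative, assuming a pos_embed without a
    # cls token): a checkpoint trained at 224x224 with patch 16 stores a [1, 196, C]
    # pos_embed (14x14 grid); loading it into a model built for 448x448 (28x28 grid)
    # bicubically resizes it to [1, 784, C].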

    def build_simple_position_embedding(self):
        pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches, self.embed_dim))
        trunc_normal_(pos_embed, std=.02)
        return pos_embed

    def build_simple_position_embedding_with_cls_token(self):
        pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
        trunc_normal_(pos_embed, std=.02)
        return pos_embed

    def build_2d_sincos_position_embedding(self, temperature=10000.):
        h, w = self.patch_embed.patches_resolution
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = self.embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature**omega)
        out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
        out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pos_embed = nn.Parameter(pos_emb)
        pos_embed.requires_grad = False
        return pos_embed

    @property
    def width(self):
        return self.num_features

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_pos_embed(self, B, H, W):
        if self.training:
            return self.pos_embed
        pos_embed = self.pos_embed
        pos_embed = interpolate_pos_encoding(pos_embed, H, W)
        return pos_embed

    def forward_features(self, x, *, return_attn=False):
        B = x.shape[0]
        x, hw_shape = self.patch_embed(x)

        x = x + self.get_pos_embed(B, *hw_shape)
        x = self.pos_drop(x)

        group_token = None
        attn_dict_list = []
        for i, layer in enumerate(self.layers):
            x, group_token, attn_dict = layer(x, group_token, return_attn=return_attn)
            if attn_dict is not None:
                attn_dict_list.append(attn_dict)

        x = self.norm(x)

        return x, group_token, attn_dict_list

    def forward_image_head(self, x):
        """
        Args:
            x: shape [B, L, C]

        Returns:
            x: shape [B, num_classes] if a classification head is attached, otherwise
                [B, C] (or projected by self.proj when CLIP weights are loaded)
        """
        # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        x = self.head(x) if self.proj is None else x @ self.proj

        return x

    def forward(self, x, *, return_feat=False, return_attn=False, as_dict=False, sampled_noun_indices=None):
        x, group_token, attn_dicts = self.forward_features(x, return_attn=return_attn)

        x_feat = x if return_feat else None

        outs = Result(as_dict=as_dict)

        outs.append(self.forward_image_head(x), name='x')

        if return_feat:
            outs.append(x_feat if self.proj is None else x_feat @ self.proj, name='feat')

        if return_attn:
            outs.append(attn_dicts, name='attn_dicts')

        return outs.as_return()
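

# A minimal usage sketch (hypothetical; relies on the default configuration above and on
# the dict-style return of the `Result` helper imported from .misc):
#
#   model = GroupViT()                                     # 3 stages, 196 patch tokens, 64/8 group tokens
#   images = torch.randn(2, 3, 224, 224)
#   out = model(images, return_feat=True, as_dict=True)
#   # out['x']   : pooled image feature, [2, 384]
#   # out['feat']: per-group features,   [2, 8, 384]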