# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
# Modified by Jilan Xu
# -------------------------------------------------------------------------

import os.path as osp

import cv2
import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from ipdb import set_trace
from mmseg.models import EncoderDecoder
from PIL import Image

from utils import get_logger

GROUP_PALETTE = np.loadtxt(osp.join(osp.dirname(osp.abspath(__file__)), 'group_palette.txt'), dtype=np.uint8)[:, ::-1]


def resize_attn_map(attentions, h, w, align_corners=False):
    """Resize per-group attention maps to the input image resolution.

    Args:
        attentions: shape [B, num_head, H*W, groups]
        h: target height (input image height)
        w: target width (input image width)

    Returns:
        attentions: shape [B, num_head, h, w, groups]
    """
    scale = (h * w // attentions.shape[2])**0.5
    if h > w:
        w_featmap = w // int(np.round(scale))
        h_featmap = attentions.shape[2] // w_featmap
    else:
        h_featmap = h // int(np.round(scale))
        w_featmap = attentions.shape[2] // h_featmap
    assert attentions.shape[
        2] == h_featmap * w_featmap, f'{attentions.shape[2]} = {h_featmap} x {w_featmap}, h={h}, w={w}'

    bs = attentions.shape[0]
    nh = attentions.shape[1]  # number of heads
    groups = attentions.shape[3]  # number of group tokens
    # [bs, nh, h*w, groups] -> [bs*nh, groups, h, w]
    attentions = rearrange(
        attentions, 'bs nh (h w) c -> (bs nh) c h w', bs=bs, nh=nh, h=h_featmap, w=w_featmap, c=groups)
    attentions = F.interpolate(attentions, size=(h, w), mode='bilinear', align_corners=align_corners)
    # [bs*nh, groups, h, w] -> [bs, nh, h, w, groups]
    attentions = rearrange(attentions, '(bs nh) c h w -> bs nh h w c', bs=bs, nh=nh, h=h, w=w, c=groups)
    return attentions
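
# Illustrative sanity check for resize_attn_map (shapes assumed for the example,
# not part of the original file): a 224x448 image encoded into a 14x28 patch grid
# with 1 head and 8 group tokens is upsampled back to pixel resolution.
#
#   attn = torch.rand(1, 1, 14 * 28, 8)          # [B, num_head, H*W, groups]
#   out = resize_attn_map(attn, 224, 448)
#   assert out.shape == (1, 1, 224, 448, 8)      # [B, num_head, h, w, groups]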


def top_groups(attn_map, k):
    """Keep only the k groups with the largest area.

    Args:
        attn_map: (B, H, W, G)
        k: int, number of groups to keep

    Returns:
        (B, H, W, G) attention map in which pixels assigned to a non-top group
        are overwritten with the values of the spatially nearest kept pixel.
    """
    attn_map = attn_map.clone()

    for i in range(attn_map.size(0)):
        # [H*W, G]
        flatten_map = rearrange(attn_map[i], 'h w g -> (h w) g')
        kept_mat = torch.zeros(flatten_map.shape[0], device=flatten_map.device, dtype=torch.bool)
        area_per_group = flatten_map.sum(dim=0)
        top_group_idx = area_per_group.topk(k=k).indices.cpu().numpy().tolist()
        for group_idx in top_group_idx:
            kept_mat[flatten_map.argmax(dim=-1) == group_idx] = True
        # [H, W, 2]
        coords = torch.stack(
            torch.meshgrid(
                torch.arange(attn_map[i].shape[0], device=attn_map[i].device, dtype=attn_map[i].dtype),
                torch.arange(attn_map[i].shape[1], device=attn_map[i].device, dtype=attn_map[i].dtype)),
            dim=-1)
        coords = rearrange(coords, 'h w c -> (h w) c')

        # squared distance between every non-kept pixel and every kept pixel
        # [non_kept, kept]
        dist_mat = torch.sum((coords[~kept_mat].unsqueeze(1) - coords[kept_mat].unsqueeze(0))**2, dim=-1)
        flatten_map[~kept_mat] = flatten_map[kept_mat.nonzero(as_tuple=True)[0][dist_mat.argmin(dim=-1)]]
        attn_map[i] = flatten_map.reshape_as(attn_map[i])

    return attn_map
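
# Illustrative usage of top_groups (tensors assumed for the example, not part of
# the original file): the output keeps the same [B, H, W, G] shape, but pixels
# whose argmax group falls outside the k largest groups are overwritten with the
# values of the nearest kept pixel.
#
#   attn = torch.rand(1, 14, 14, 8)              # [B, H, W, G]
#   pruned = top_groups(attn, k=4)               # still [1, 14, 14, 8]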


def seg2coord(seg_map):
    """Compute the center coordinate of every segment in a label map.

    Args:
        seg_map (np.ndarray): (H, W)

    Returns:
        dict mapping each label to the mean coordinate of its pixels as (row, col)
    """
    h, w = seg_map.shape
    # [h, w, 2]
    coords = np.stack(np.meshgrid(np.arange(h), np.arange(w), indexing='ij'), axis=-1)
    labels = np.unique(seg_map)
    coord_map = {}
    for label in labels:
        coord_map[label] = coords[seg_map == label].mean(axis=0)
    return coord_map
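
# Illustrative usage of seg2coord (input assumed for the example, not part of the
# original file): each label is mapped to the mean (row, col) of its pixels.
#
#   seg = np.array([[0, 0],
#                   [1, 1]])
#   seg2coord(seg)  # {0: array([0. , 0.5]), 1: array([1. , 0.5])}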


class GroupViTSegInference(EncoderDecoder):

    def __init__(self, model, text_embedding, with_bg, test_cfg=dict(mode='whole', bg_thresh=.95)):
        super(EncoderDecoder, self).__init__()
        if not isinstance(test_cfg, mmcv.Config):
            test_cfg = mmcv.Config(test_cfg)
        self.test_cfg = test_cfg
        self.model = model
        # [N, C]
        self.register_buffer('text_embedding', text_embedding)
        self.with_bg = with_bg
        self.bg_thresh = test_cfg['bg_thresh']
        if self.with_bg:
            self.num_classes = len(text_embedding) + 1
        else:
            self.num_classes = len(text_embedding)
        self.align_corners = False
        logger = get_logger()
        logger.info(
            f'Building GroupViTSegInference with {self.num_classes} classes, test_cfg={test_cfg}, with_bg={with_bg}')

    def forward_train(self, img, img_metas, gt_semantic_seg):
        raise NotImplementedError

    def get_attn_maps(self, img, return_onehot=False, rescale=False):
        """
        Args:
            img: [B, C, H, W]
            return_onehot (bool): if True, convert each map to a one-hot group assignment
            rescale (bool): if True, upsample each map to the input image resolution

        Returns:
            attn_maps: list[Tensor], attention maps of shape [B, H, W, groups]
        """
        results = self.model.img_encoder(img, return_attn=True, as_dict=True)

        attn_maps = []
        with torch.no_grad():
            prev_attn_masks = None
            for idx, attn_dict in enumerate(results['attn_dicts']):
                if attn_dict is None:
                    # not every layer produces a grouping attention; skip those that don't
                    continue
                # [B, nH, G, HxW]
                # B: batch size (1), nH: number of heads, G: number of group tokens
                attn_masks = attn_dict['soft']
                # [B, nH, G, HxW] -> [B, nH, HxW, G]
                attn_masks = rearrange(attn_masks, 'b h g n -> b h n g')
                if prev_attn_masks is None:
                    prev_attn_masks = attn_masks
                else:
                    # chain the grouping stages: pixels -> previous groups -> current groups
                    prev_attn_masks = prev_attn_masks @ attn_masks
                # [B, nH, HxW, G] -> [B, nH, H, W, G]
                attn_maps.append(resize_attn_map(prev_attn_masks, *img.shape[-2:]))

        for i in range(len(attn_maps)):
            attn_map = attn_maps[i]
            # [B, nH, H, W, G]
            assert attn_map.shape[1] == 1
            # [B, H, W, G]
            attn_map = attn_map.squeeze(1)

            if rescale:
                attn_map = rearrange(attn_map, 'b h w g -> b g h w')
                attn_map = F.interpolate(
                    attn_map, size=img.shape[2:], mode='bilinear', align_corners=self.align_corners)
                attn_map = rearrange(attn_map, 'b g h w -> b h w g')

            if return_onehot:
                # [B, H, W, G]
                attn_map = F.one_hot(attn_map.argmax(dim=-1), num_classes=attn_map.shape[-1]).to(dtype=attn_map.dtype)

            attn_maps[i] = attn_map

        return attn_maps
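
    # Sketch of the stage-wise composition performed in get_attn_maps above (the group
    # counts are assumed for the example, not part of the original file): with 64 group
    # tokens in stage 0 and 8 in stage 1,
    #   stage 0 soft attention, rearranged: [B, nH, HxW, 64]  (pixels -> groups)
    #   stage 1 soft attention, rearranged: [B, nH, 64, 8]    (groups -> groups)
    # and their matrix product [B, nH, HxW, 8] assigns every pixel to a final group,
    # which resize_attn_map then reshapes back to [B, nH, H, W, 8].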

    def encode_decode(self, img, img_metas):
        """Encode images with backbone and decode into a semantic segmentation
        map of the same size as input."""
        assert img.shape[0] == 1, 'batch size must be 1'

        # [B, H, W, G], take the attention map of the last grouping stage only
        attn_map = self.get_attn_maps(img, rescale=True)[-1]
        # [H, W, G], select batch idx 0
        attn_map = attn_map[0]

        img_outs = self.model.encode_image(img, return_feat=True, as_dict=True)
        # [B, L, C] -> [L, C]
        grouped_img_tokens = img_outs['image_feat'].squeeze(0)
        img_avg_feat = img_outs['image_x']
        # [G, C]
        grouped_img_tokens = F.normalize(grouped_img_tokens, dim=-1)
        img_avg_feat = F.normalize(img_avg_feat, dim=-1)

        # [H, W, G]
        onehot_attn_map = F.one_hot(attn_map.argmax(dim=-1), num_classes=attn_map.shape[-1]).to(dtype=attn_map.dtype)

        num_fg_classes = self.text_embedding.shape[0]
        class_offset = 1 if self.with_bg else 0
        text_tokens = self.text_embedding
        num_classes = num_fg_classes + class_offset

        logit_scale = torch.clamp(self.model.logit_scale.exp(), max=100)
        # [G, N]
        group_affinity_mat = (grouped_img_tokens @ text_tokens.T) * logit_scale
        pre_group_affinity_mat = F.softmax(group_affinity_mat, dim=-1)

        avg_affinity_mat = (img_avg_feat @ text_tokens.T) * logit_scale
        avg_affinity_mat = F.softmax(avg_affinity_mat, dim=-1)
        affinity_mask = torch.zeros_like(avg_affinity_mat)
        avg_affinity_topk = avg_affinity_mat.topk(dim=-1, k=min(5, num_fg_classes))
        affinity_mask.scatter_add_(
            dim=-1, index=avg_affinity_topk.indices, src=torch.ones_like(avg_affinity_topk.values))
        group_affinity_mat.masked_fill_(~affinity_mask.bool(), float('-inf'))
        group_affinity_mat = F.softmax(group_affinity_mat, dim=-1)

        # TODO: check if necessary
        group_affinity_mat *= pre_group_affinity_mat

        pred_logits = torch.zeros(num_classes, *attn_map.shape[:2], device=img.device, dtype=img.dtype)
        pred_logits[class_offset:] = rearrange(onehot_attn_map @ group_affinity_mat, 'h w c -> c h w')
        if self.with_bg:
            bg_thresh = min(self.bg_thresh, group_affinity_mat.max().item())
            pred_logits[0, (onehot_attn_map @ group_affinity_mat).max(dim=-1).values < bg_thresh] = 1

        return pred_logits.unsqueeze(0)
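
    # Sketch of the zero-shot group-to-class assignment in encode_decode above
    # (shape names only; nothing beyond the code itself is introduced): with G group
    # tokens, N text classes and feature dimension C,
    #   grouped_img_tokens: [G, C] and text_tokens: [N, C], both L2-normalized,
    #   group_affinity_mat = softmax(logit_scale * grouped_img_tokens @ text_tokens.T)  # [G, N]
    # restricted to the top-5 classes (at most) predicted from the image-level feature.
    # The per-pixel logits are onehot_attn_map [H, W, G] @ group_affinity_mat [G, N],
    # i.e. every pixel inherits the class scores of the group it is assigned to.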

    def blend_result(self, img, result, palette=None, out_file=None, opacity=0.5, with_bg=False):
        img = mmcv.imread(img)
        img = img.copy()
        seg = result[0]
        if palette is None:
            palette = self.PALETTE
        palette = np.array(palette)
        assert palette.shape[1] == 3, palette.shape
        assert len(palette.shape) == 2
        assert 0 < opacity <= 1.0
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        color_seg = color_seg[..., ::-1]

        if with_bg:
            fg_mask = seg != 0
            img[fg_mask] = img[fg_mask] * (1 - opacity) + color_seg[fg_mask] * opacity
        else:
            img = img * (1 - opacity) + color_seg * opacity
        img = img.astype(np.uint8)

        if out_file is not None:
            mmcv.imwrite(img, out_file)

        return img

    def show_result(self, img_show, img_tensor, result, out_file, vis_mode='pred'):
        print('current vis mode: ', vis_mode)
        assert vis_mode in [
            'input', 'pred', 'input_pred', 'all_groups', 'first_group', 'final_group', 'input_pred_label', 'mask',
        ], vis_mode

        if vis_mode == 'input':
            mmcv.imwrite(img_show, out_file)
        elif vis_mode == 'pred':
            output = Image.fromarray(result[0].astype(np.uint8)).convert('P')
            output.putpalette(np.array(self.PALETTE).astype(np.uint8))
            mmcv.mkdir_or_exist(osp.dirname(out_file))
            output.save(out_file.replace('.jpg', '.png'))
        elif vis_mode == 'input_pred':
            self.blend_result(img=img_show, result=result, out_file=out_file, opacity=0.5, with_bg=self.with_bg)
        elif vis_mode == 'input_pred_label':
            labels = np.unique(result[0])
            coord_map = seg2coord(result[0])
            # reference: https://github.com/open-mmlab/mmdetection/blob/ff9bc39913cb3ff5dde79d3933add7dc2561bab7/mmdet/models/detectors/base.py#L271 # noqa
            blended_img = self.blend_result(
                img=img_show, result=result, out_file=None, opacity=0.5, with_bg=self.with_bg)
            blended_img = mmcv.bgr2rgb(blended_img)
            width, height = img_show.shape[1], img_show.shape[0]
            EPS = 1e-2
            fig = plt.figure(frameon=False)
            canvas = fig.canvas
            dpi = fig.get_dpi()
            fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)
            # remove white edges by setting subplot margins
            plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
            ax = plt.gca()
            ax.axis('off')
            for i, label in enumerate(labels):
                if self.with_bg and label == 0:
                    continue
                center = coord_map[label].astype(np.int32)
                label_text = self.CLASSES[label]
                ax.text(
                    center[1],
                    center[0],
                    f'{label_text}',
                    bbox={
                        'facecolor': 'black',
                        'alpha': 0.5,
                        'pad': 0.7,
                        'edgecolor': 'none'
                    },
                    color='orangered',
                    fontsize=16,
                    verticalalignment='top',
                    horizontalalignment='left')
            plt.imshow(blended_img)
            stream, _ = canvas.print_to_buffer()
            buffer = np.frombuffer(stream, dtype='uint8')
            img_rgba = buffer.reshape(height, width, 4)
            rgb, alpha = np.split(img_rgba, [3], axis=2)
            img = rgb.astype('uint8')
            img = mmcv.rgb2bgr(img)
            mmcv.imwrite(img, out_file)
            plt.close()
        elif vis_mode in ['all_groups', 'final_group', 'first_group']:
            attn_map_list = self.get_attn_maps(img_tensor)
            assert len(attn_map_list) in [1, 2]
            # only show 16 groups for the first stage
            # if len(attn_map_list) == 2:
            #     attn_map_list[0] = top_groups(attn_map_list[0], k=16)
            num_groups = [attn_map_list[layer_idx].shape[-1] for layer_idx in range(len(attn_map_list))]
            for layer_idx, attn_map in enumerate(attn_map_list):
                if vis_mode == 'first_group' and layer_idx != 0:
                    continue
                if vis_mode == 'final_group' and layer_idx != len(attn_map_list) - 1:
                    continue
                attn_map = rearrange(attn_map, 'b h w g -> b g h w')
                attn_map = F.interpolate(
                    attn_map, size=img_show.shape[:2], mode='bilinear', align_corners=self.align_corners)
                group_result = attn_map.argmax(dim=1).cpu().numpy()
                if vis_mode == 'all_groups':
                    layer_out_file = out_file.replace(
                        osp.splitext(out_file)[-1], f'_layer{layer_idx}{osp.splitext(out_file)[-1]}')
                else:
                    layer_out_file = out_file
                self.blend_result(
                    img=img_show,
                    result=group_result,
                    palette=GROUP_PALETTE[sum(num_groups[:layer_idx]):sum(num_groups[:layer_idx + 1])],
                    out_file=layer_out_file,
                    opacity=0.5)
        elif vis_mode == 'mask':
            mask = result[0]
            mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
            # mask.putpalette(np.array(self.PALETTE).astype(np.uint8))
            mmcv.mkdir_or_exist(osp.dirname(out_file))
            mask.save(out_file.replace('.jpg', '.png'))
        else:
            raise ValueError(f'Unknown vis_type: {vis_mode}')
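

# Assumed usage sketch (not part of the original file): the wrapper is built from a
# trained GroupViT-style model and per-class text embeddings, then driven through the
# mmseg EncoderDecoder interface it inherits. blend_result/show_result additionally
# expect CLASSES and PALETTE attributes to be attached by the caller.
#
#   seg_model = GroupViTSegInference(model, text_embedding, with_bg=True,
#                                    test_cfg=dict(mode='whole', bg_thresh=0.95))
#   seg_model.CLASSES = class_names      # hypothetical variables supplied by the caller
#   seg_model.PALETTE = palette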