group_vit_seg.py

# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
import os.path as osp

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from mmseg.models import EncoderDecoder
from PIL import Image
from utils import get_logger

GROUP_PALETTE = np.loadtxt(osp.join(osp.dirname(osp.abspath(__file__)), 'group_palette.txt'), dtype=np.uint8)[:, ::-1]


def resize_attn_map(attentions, h, w, align_corners=False):
    """
    Args:
        attentions: shape [B, num_head, H*W, groups]
        h (int): target height
        w (int): target width

    Returns:
        attentions: shape [B, num_head, h, w, groups]
    """
    scale = (h * w // attentions.shape[2])**0.5
    if h > w:
        w_featmap = w // int(np.round(scale))
        h_featmap = attentions.shape[2] // w_featmap
    else:
        h_featmap = h // int(np.round(scale))
        w_featmap = attentions.shape[2] // h_featmap
    assert attentions.shape[2] == h_featmap * w_featmap, \
        f'{attentions.shape[2]} = {h_featmap} x {w_featmap}, h={h}, w={w}'

    bs = attentions.shape[0]
    nh = attentions.shape[1]  # number of heads
    groups = attentions.shape[3]  # number of group tokens
    # [bs, nh, h*w, groups] -> [bs*nh, groups, h, w]
    attentions = rearrange(
        attentions, 'bs nh (h w) c -> (bs nh) c h w', bs=bs, nh=nh, h=h_featmap, w=w_featmap, c=groups)
    attentions = F.interpolate(attentions, size=(h, w), mode='bilinear', align_corners=align_corners)
    # [bs*nh, groups, h, w] -> [bs, nh, h, w, groups]
    attentions = rearrange(attentions, '(bs nh) c h w -> bs nh h w c', bs=bs, nh=nh, h=h, w=w, c=groups)
    return attentions
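
# Usage sketch (illustrative values, not part of the original module): upsample the
# group attention over a 14x14 patch grid to a 224x224 image.
#
#   attn = torch.rand(2, 1, 14 * 14, 8)          # [B, num_head, H*W, groups]
#   full = resize_attn_map(attn, h=224, w=224)   # -> [2, 1, 224, 224, 8]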


def top_groups(attn_map, k):
    """Keep only the k largest groups per image.

    Args:
        attn_map: (B, H, W, G)
        k: int

    Return:
        (B, H, W, G), where pixels assigned to non-top-k groups are replaced by the
        attention values of the nearest pixel assigned to a top-k group
    """
    attn_map = attn_map.clone()

    for i in range(attn_map.size(0)):
        # [H*W, G]
        flatten_map = rearrange(attn_map[i], 'h w g -> (h w) g')
        kept_mat = torch.zeros(flatten_map.shape[0], device=flatten_map.device, dtype=torch.bool)
        area_per_group = flatten_map.sum(dim=0)
        top_group_idx = area_per_group.topk(k=k).indices.cpu().numpy().tolist()
        for group_idx in top_group_idx:
            kept_mat[flatten_map.argmax(dim=-1) == group_idx] = True
        # [H, W, 2]
        coords = torch.stack(
            torch.meshgrid(
                torch.arange(attn_map[i].shape[0], device=attn_map[i].device, dtype=attn_map[i].dtype),
                torch.arange(attn_map[i].shape[1], device=attn_map[i].device, dtype=attn_map[i].dtype)),
            dim=-1)
        coords = rearrange(coords, 'h w c -> (h w) c')

        # calculate distance between each pair of points
        # [non_kept, kept]
        dist_mat = torch.sum((coords[~kept_mat].unsqueeze(1) - coords[kept_mat].unsqueeze(0))**2, dim=-1)
        flatten_map[~kept_mat] = flatten_map[kept_mat.nonzero(as_tuple=True)[0][dist_mat.argmin(dim=-1)]]

        attn_map[i] = flatten_map.reshape_as(attn_map[i])

    return attn_map
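
# Usage sketch (illustrative): keep only the 8 largest groups; pixels whose argmax group
# is not among the top 8 are overwritten with the attention row of the nearest kept pixel.
#
#   attn = torch.rand(1, 56, 56, 64)    # [B, H, W, G]
#   pruned = top_groups(attn, k=8)      # -> [1, 56, 56, 64]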


def seg2coord(seg_map):
    """
    Args:
        seg_map (np.ndarray): (H, W)

    Return:
        dict(label -> centroid (row, col))
    """
    h, w = seg_map.shape
    # [h, w, 2]
    coords = np.stack(np.meshgrid(np.arange(h), np.arange(w), indexing='ij'), axis=-1)
    labels = np.unique(seg_map)
    coord_map = {}
    for label in labels:
        coord_map[label] = coords[seg_map == label].mean(axis=0)

    return coord_map
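
# Usage sketch (illustrative): centroid of each label in a segmentation map.
#
#   seg = np.zeros((4, 4), dtype=np.int64)
#   seg[2:, 2:] = 1
#   seg2coord(seg)[1]    # -> array([2.5, 2.5]), i.e. (row, col)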


class GroupViTSegInference(EncoderDecoder):

    def __init__(self, model, text_embedding, with_bg, test_cfg=dict(mode='whole', bg_thresh=.95)):
        # skip EncoderDecoder.__init__ (which would build a backbone/decode_head from a
        # config) and initialize the BaseSegmentor base class directly
        super(EncoderDecoder, self).__init__()
        if not isinstance(test_cfg, mmcv.Config):
            test_cfg = mmcv.Config(test_cfg)
        self.test_cfg = test_cfg
        self.model = model
        # [N, C]
        self.register_buffer('text_embedding', text_embedding)
        self.with_bg = with_bg
        self.bg_thresh = test_cfg['bg_thresh']
        if self.with_bg:
            self.num_classes = len(text_embedding) + 1
        else:
            self.num_classes = len(text_embedding)

        self.align_corners = False
        logger = get_logger()
        logger.info(
            f'Building GroupViTSegInference with {self.num_classes} classes, test_cfg={test_cfg}, with_bg={with_bg}')

    def forward_train(self, img, img_metas, gt_semantic_seg):
        raise NotImplementedError

    def get_attn_maps(self, img, return_onehot=False, rescale=False):
        """
        Args:
            img: [B, C, H, W]
            return_onehot (bool): if True, convert each map to a one-hot group assignment via argmax
            rescale (bool): if True, bilinearly resize the maps to the input image resolution

        Returns:
            attn_maps: list[Tensor], attention map of shape [B, H, W, groups]
        """
        results = self.model.img_encoder(img, return_attn=True, as_dict=True)

        attn_maps = []
        with torch.no_grad():
            prev_attn_masks = None
            for idx, attn_dict in enumerate(results['attn_dicts']):
                if attn_dict is None:
                    assert idx == len(results['attn_dicts']) - 1, 'only last layer can be None'
                    continue
                # [B, nH, G, HxW]
                # B: batch size (1), nH: number of heads, G: number of group tokens
                attn_masks = attn_dict['soft']
                # [B, nH, G, HxW] -> [B, nH, HxW, G]
                attn_masks = rearrange(attn_masks, 'b h g n -> b h n g')
                if prev_attn_masks is None:
                    prev_attn_masks = attn_masks
                else:
                    prev_attn_masks = prev_attn_masks @ attn_masks
                # [B, nH, HxW, G] -> [B, nH, H, W, G]
                attn_maps.append(resize_attn_map(prev_attn_masks, *img.shape[-2:]))

        for i in range(len(attn_maps)):
            attn_map = attn_maps[i]
            # [B, nH, H, W, G]
            assert attn_map.shape[1] == 1
            # [B, H, W, G]
            attn_map = attn_map.squeeze(1)

            if rescale:
                attn_map = rearrange(attn_map, 'b h w g -> b g h w')
                attn_map = F.interpolate(
                    attn_map, size=img.shape[2:], mode='bilinear', align_corners=self.align_corners)
                attn_map = rearrange(attn_map, 'b g h w -> b h w g')

            if return_onehot:
                # [B, H, W, G]
                attn_map = F.one_hot(attn_map.argmax(dim=-1), num_classes=attn_map.shape[-1]).to(dtype=attn_map.dtype)

            attn_maps[i] = attn_map

        return attn_maps
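
    # Usage sketch (illustrative shapes): for a two-stage GroupViT on a 224x224 input,
    # get_attn_maps(img) is expected to return a list of two tensors, e.g.
    # [B, 224, 224, 64] for the first grouping stage and [B, 224, 224, 8] for the final
    # stage (the actual group counts depend on the model config).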

    def encode_decode(self, img, img_metas):
        """Encode images with backbone and decode into a semantic segmentation
        map of the same size as input."""
        assert img.shape[0] == 1, 'batch size must be 1'

        # [B, C, H, W], get the last one only
        attn_map = self.get_attn_maps(img, rescale=True)[-1]
        # [H, W, G], select batch idx 0
        attn_map = attn_map[0]

        img_outs = self.model.encode_image(img, return_feat=True, as_dict=True)
        # [B, L, C] -> [L, C]
        grouped_img_tokens = img_outs['image_feat'].squeeze(0)
        img_avg_feat = img_outs['image_x']
        # [G, C]
        grouped_img_tokens = F.normalize(grouped_img_tokens, dim=-1)
        img_avg_feat = F.normalize(img_avg_feat, dim=-1)

        # [H, W, G]
        onehot_attn_map = F.one_hot(attn_map.argmax(dim=-1), num_classes=attn_map.shape[-1]).to(dtype=attn_map.dtype)

        num_fg_classes = self.text_embedding.shape[0]
        class_offset = 1 if self.with_bg else 0
        text_tokens = self.text_embedding
        num_classes = num_fg_classes + class_offset

        logit_scale = torch.clamp(self.model.logit_scale.exp(), max=100)
        # [G, N]
        group_affinity_mat = (grouped_img_tokens @ text_tokens.T) * logit_scale
        pre_group_affinity_mat = F.softmax(group_affinity_mat, dim=-1)

        avg_affinity_mat = (img_avg_feat @ text_tokens.T) * logit_scale
        avg_affinity_mat = F.softmax(avg_affinity_mat, dim=-1)
        affinity_mask = torch.zeros_like(avg_affinity_mat)
        avg_affinity_topk = avg_affinity_mat.topk(dim=-1, k=min(5, num_fg_classes))
        affinity_mask.scatter_add_(
            dim=-1, index=avg_affinity_topk.indices, src=torch.ones_like(avg_affinity_topk.values))
        group_affinity_mat.masked_fill_(~affinity_mask.bool(), float('-inf'))
        group_affinity_mat = F.softmax(group_affinity_mat, dim=-1)

        # TODO: check if necessary
        group_affinity_mat *= pre_group_affinity_mat

        pred_logits = torch.zeros(num_classes, *attn_map.shape[:2], device=img.device, dtype=img.dtype)
        pred_logits[class_offset:] = rearrange(onehot_attn_map @ group_affinity_mat, 'h w c -> c h w')
        if self.with_bg:
            bg_thresh = min(self.bg_thresh, group_affinity_mat.max().item())
            pred_logits[0, (onehot_attn_map @ group_affinity_mat).max(dim=-1).values < bg_thresh] = 1

        return pred_logits.unsqueeze(0)
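
    # Usage sketch (hypothetical inputs; in mmseg the usual caller is
    # EncoderDecoder.whole_inference): given a preprocessed image tensor `img`
    # of shape [1, 3, H, W],
    #
    #   seg_logits = seg_model.encode_decode(img, img_metas=None)   # [1, num_classes, H, W]
    #   pred = seg_logits.argmax(dim=1)[0].cpu().numpy()            # [H, W] label map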

    def blend_result(self, img, result, palette=None, out_file=None, opacity=0.5, with_bg=False):
        img = mmcv.imread(img)
        img = img.copy()
        seg = result[0]
        if palette is None:
            palette = self.PALETTE
        palette = np.array(palette)
        assert palette.shape[1] == 3, palette.shape
        assert len(palette.shape) == 2
        assert 0 < opacity <= 1.0
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        color_seg = color_seg[..., ::-1]

        if with_bg:
            fg_mask = seg != 0
            img[fg_mask] = img[fg_mask] * (1 - opacity) + color_seg[fg_mask] * opacity
        else:
            img = img * (1 - opacity) + color_seg * opacity
        img = img.astype(np.uint8)
        if out_file is not None:
            mmcv.imwrite(img, out_file)

        return img
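
    # Usage sketch (hypothetical `seg_model` and `pred`, and assuming `seg_model.PALETTE`
    # has been attached): overlay a predicted [H, W] label map on the input image and
    # write it to disk.
    #
    #   seg_model.blend_result(img='demo.jpg', result=[pred], out_file='blend.jpg',
    #                          opacity=0.5, with_bg=seg_model.with_bg)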

    def show_result(self, img_show, img_tensor, result, out_file, vis_mode='input'):
        assert vis_mode in [
            'input', 'pred', 'input_pred', 'all_groups', 'first_group', 'final_group', 'input_pred_label'
        ], vis_mode

        if vis_mode == 'input':
            mmcv.imwrite(img_show, out_file)
        elif vis_mode == 'pred':
            output = Image.fromarray(result[0].astype(np.uint8)).convert('P')
            output.putpalette(np.array(self.PALETTE).astype(np.uint8))
            mmcv.mkdir_or_exist(osp.dirname(out_file))
            output.save(out_file.replace('.jpg', '.png'))
        elif vis_mode == 'input_pred':
            self.blend_result(img=img_show, result=result, out_file=out_file, opacity=0.5, with_bg=self.with_bg)
        elif vis_mode == 'input_pred_label':
            labels = np.unique(result[0])
            coord_map = seg2coord(result[0])
            # reference: https://github.com/open-mmlab/mmdetection/blob/ff9bc39913cb3ff5dde79d3933add7dc2561bab7/mmdet/models/detectors/base.py#L271 # noqa
            blended_img = self.blend_result(
                img=img_show, result=result, out_file=None, opacity=0.5, with_bg=self.with_bg)
            blended_img = mmcv.bgr2rgb(blended_img)
            width, height = img_show.shape[1], img_show.shape[0]
            EPS = 1e-2
            fig = plt.figure(frameon=False)
            canvas = fig.canvas
            dpi = fig.get_dpi()
            fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)
            # remove white edges by setting subplot margins
            plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
            ax = plt.gca()
            ax.axis('off')
            for i, label in enumerate(labels):
                if self.with_bg and label == 0:
                    continue
                center = coord_map[label].astype(np.int32)
                label_text = self.CLASSES[label]
                ax.text(
                    center[1],
                    center[0],
                    f'{label_text}',
                    bbox={
                        'facecolor': 'black',
                        'alpha': 0.5,
                        'pad': 0.7,
                        'edgecolor': 'none'
                    },
                    color='orangered',
                    fontsize=16,
                    verticalalignment='top',
                    horizontalalignment='left')
            plt.imshow(blended_img)
            stream, _ = canvas.print_to_buffer()
            buffer = np.frombuffer(stream, dtype='uint8')
            img_rgba = buffer.reshape(height, width, 4)
            rgb, alpha = np.split(img_rgba, [3], axis=2)
            img = rgb.astype('uint8')
            img = mmcv.rgb2bgr(img)
            mmcv.imwrite(img, out_file)
            plt.close()
        elif vis_mode == 'all_groups' or vis_mode == 'final_group' or vis_mode == 'first_group':
            attn_map_list = self.get_attn_maps(img_tensor)
            assert len(attn_map_list) in [1, 2]
            # only show 16 groups for the first stage
            # if len(attn_map_list) == 2:
            #     attn_map_list[0] = top_groups(attn_map_list[0], k=16)

            num_groups = [attn_map_list[layer_idx].shape[-1] for layer_idx in range(len(attn_map_list))]
            for layer_idx, attn_map in enumerate(attn_map_list):
                if vis_mode == 'first_group' and layer_idx != 0:
                    continue
                if vis_mode == 'final_group' and layer_idx != len(attn_map_list) - 1:
                    continue
                attn_map = rearrange(attn_map, 'b h w g -> b g h w')
                attn_map = F.interpolate(
                    attn_map, size=img_show.shape[:2], mode='bilinear', align_corners=self.align_corners)
                group_result = attn_map.argmax(dim=1).cpu().numpy()
                if vis_mode == 'all_groups':
                    layer_out_file = out_file.replace(
                        osp.splitext(out_file)[-1], f'_layer{layer_idx}{osp.splitext(out_file)[-1]}')
                else:
                    layer_out_file = out_file
                self.blend_result(
                    img=img_show,
                    result=group_result,
                    palette=GROUP_PALETTE[sum(num_groups[:layer_idx]):sum(num_groups[:layer_idx + 1])],
                    out_file=layer_out_file,
                    opacity=0.5)
        else:
            raise ValueError(f'Unknown vis_mode: {vis_mode}')
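
# End-to-end usage sketch (all names below are assumptions, not part of this module: a
# pretrained GroupViT `group_vit`, precomputed foreground-class text embeddings
# `text_embedding` of shape [num_fg_classes, C], a normalized image tensor `img` of shape
# [1, 3, H, W], and the matching unnormalized BGR image `img_show` as an HxWx3 np.ndarray;
# CLASSES/PALETTE are attached externally, as the visualization code above expects):
#
#   seg_model = GroupViTSegInference(group_vit, text_embedding, with_bg=True)
#   seg_model.CLASSES = ('background', 'cat', 'dog')
#   seg_model.PALETTE = [[0, 0, 0], [255, 0, 0], [0, 255, 0]]
#   seg_logits = seg_model.encode_decode(img, img_metas=None)       # [1, num_classes, H, W]
#   pred = seg_logits.argmax(dim=1)[0].cpu().numpy()                # [H, W] label map
#   seg_model.show_result(img_show, img, [pred], out_file='vis.jpg',
#                         vis_mode='input_pred_label')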