convert_coco_object.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
  3. #
  4. # This work is made available under the Nvidia Source Code License.
  5. # To view a copy of this license, visit
  6. # https://github.com/NVlabs/GroupViT/blob/main/LICENSE
  7. #
  8. # Written by Jiarui Xu
  9. # -------------------------------------------------------------------------
  10. import argparse
  11. import os.path as osp
  12. import shutil
  13. from functools import partial
  14. from glob import glob
  15. import mmcv
  16. import numpy as np
  17. from PIL import Image
  18. COCO_LEN = 123287
  19. clsID_to_trID = {
  20. 0: 0,
  21. 1: 1,
  22. 2: 2,
  23. 3: 3,
  24. 4: 4,
  25. 5: 5,
  26. 6: 6,
  27. 7: 7,
  28. 8: 8,
  29. 9: 9,
  30. 10: 10,
  31. 12: 11,
  32. 13: 12,
  33. 14: 13,
  34. 15: 14,
  35. 16: 15,
  36. 17: 16,
  37. 18: 17,
  38. 19: 18,
  39. 20: 19,
  40. 21: 20,
  41. 22: 21,
  42. 23: 22,
  43. 24: 23,
  44. 26: 24,
  45. 27: 25,
  46. 30: 26,
  47. 31: 27,
  48. 32: 28,
  49. 33: 29,
  50. 34: 30,
  51. 35: 31,
  52. 36: 32,
  53. 37: 33,
  54. 38: 34,
  55. 39: 35,
  56. 40: 36,
  57. 41: 37,
  58. 42: 38,
  59. 43: 39,
  60. 45: 40,
  61. 46: 41,
  62. 47: 42,
  63. 48: 43,
  64. 49: 44,
  65. 50: 45,
  66. 51: 46,
  67. 52: 47,
  68. 53: 48,
  69. 54: 49,
  70. 55: 50,
  71. 56: 51,
  72. 57: 52,
  73. 58: 53,
  74. 59: 54,
  75. 60: 55,
  76. 61: 56,
  77. 62: 57,
  78. 63: 58,
  79. 64: 59,
  80. 66: 60,
  81. 69: 61,
  82. 71: 62,
  83. 72: 63,
  84. 73: 64,
  85. 74: 65,
  86. 75: 66,
  87. 76: 67,
  88. 77: 68,
  89. 78: 69,
  90. 79: 70,
  91. 80: 71,
  92. 81: 72,
  93. 83: 73,
  94. 84: 74,
  95. 85: 75,
  96. 86: 76,
  97. 87: 77,
  98. 88: 78,
  99. 89: 79,
  100. 91: 80,
  101. 92: 81,
  102. 93: 82,
  103. 94: 83,
  104. 95: 84,
  105. 96: 85,
  106. 97: 86,
  107. 98: 87,
  108. 99: 88,
  109. 100: 89,
  110. 101: 90,
  111. 102: 91,
  112. 103: 92,
  113. 104: 93,
  114. 105: 94,
  115. 106: 95,
  116. 107: 96,
  117. 108: 97,
  118. 109: 98,
  119. 110: 99,
  120. 111: 100,
  121. 112: 101,
  122. 113: 102,
  123. 114: 103,
  124. 115: 104,
  125. 116: 105,
  126. 117: 106,
  127. 118: 107,
  128. 119: 108,
  129. 120: 109,
  130. 121: 110,
  131. 122: 111,
  132. 123: 112,
  133. 124: 113,
  134. 125: 114,
  135. 126: 115,
  136. 127: 116,
  137. 128: 117,
  138. 129: 118,
  139. 130: 119,
  140. 131: 120,
  141. 132: 121,
  142. 133: 122,
  143. 134: 123,
  144. 135: 124,
  145. 136: 125,
  146. 137: 126,
  147. 138: 127,
  148. 139: 128,
  149. 140: 129,
  150. 141: 130,
  151. 142: 131,
  152. 143: 132,
  153. 144: 133,
  154. 145: 134,
  155. 146: 135,
  156. 147: 136,
  157. 148: 137,
  158. 149: 138,
  159. 150: 139,
  160. 151: 140,
  161. 152: 141,
  162. 153: 142,
  163. 154: 143,
  164. 155: 144,
  165. 156: 145,
  166. 157: 146,
  167. 158: 147,
  168. 159: 148,
  169. 160: 149,
  170. 161: 150,
  171. 162: 151,
  172. 163: 152,
  173. 164: 153,
  174. 165: 154,
  175. 166: 155,
  176. 167: 156,
  177. 168: 157,
  178. 169: 158,
  179. 170: 159,
  180. 171: 160,
  181. 172: 161,
  182. 173: 162,
  183. 174: 163,
  184. 175: 164,
  185. 176: 165,
  186. 177: 166,
  187. 178: 167,
  188. 179: 168,
  189. 180: 169,
  190. 181: 170,
  191. 255: 255
  192. }
  193. # set to background
  194. for k, v in clsID_to_trID.items():
  195. clsID_to_trID[k] = v + 1
  196. if k > 90:
  197. clsID_to_trID[k] = 0
  198. def convert_to_trainID(maskpath, out_mask_dir, is_train):
  199. mask = np.array(Image.open(maskpath))
  200. mask_copy = mask.copy()
  201. for clsID, trID in clsID_to_trID.items():
  202. mask_copy[mask == clsID] = trID
  203. seg_filename = osp.join(
  204. out_mask_dir, 'train2017',
  205. osp.basename(maskpath).split('.')[0] +
  206. '_instanceTrainIds.png') if is_train else osp.join(
  207. out_mask_dir, 'val2017',
  208. osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png')
  209. Image.fromarray(mask_copy).save(seg_filename, 'PNG')
  210. def parse_args():
  211. parser = argparse.ArgumentParser(
  212. description=\
  213. 'Convert COCO Stuff 164k annotations to COCO Objects') # noqa
  214. parser.add_argument('coco_path', help='coco stuff path')
  215. parser.add_argument('-o', '--out_dir', help='output path')
  216. parser.add_argument(
  217. '--nproc', default=16, type=int, help='number of process')
  218. args = parser.parse_args()
  219. return args
  220. def main():
  221. args = parse_args()
  222. coco_path = args.coco_path
  223. nproc = args.nproc
  224. out_dir = args.out_dir or coco_path
  225. out_img_dir = osp.join(out_dir, 'images')
  226. out_mask_dir = osp.join(out_dir, 'annotations')
  227. mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
  228. mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))
  229. if out_dir != coco_path:
  230. shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)
  231. train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
  232. train_list = [file for file in train_list if 'TrainIds' not in file]
  233. test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
  234. test_list = [file for file in test_list if 'TrainIds' not in file]
  235. assert (len(train_list) +
  236. len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
  237. len(train_list), len(test_list))
  238. if args.nproc > 1:
  239. mmcv.track_parallel_progress(
  240. partial(
  241. convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
  242. train_list,
  243. nproc=nproc)
  244. mmcv.track_parallel_progress(
  245. partial(
  246. convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
  247. test_list,
  248. nproc=nproc)
  249. else:
  250. mmcv.track_progress(
  251. partial(
  252. convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
  253. train_list)
  254. mmcv.track_progress(
  255. partial(
  256. convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
  257. test_list)
  258. print('Done!')
  259. if __name__ == '__main__':
  260. main()