# ------------------------------------------------------------------------- # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved. # # This work is made available under the Nvidia Source Code License. # To view a copy of this license, visit # https://github.com/NVlabs/GroupViT/blob/main/LICENSE # # Written by Jiarui Xu # ------------------------------------------------------------------------- import argparse import os.path as osp import shutil from functools import partial from glob import glob import mmcv import numpy as np from PIL import Image COCO_LEN = 123287 clsID_to_trID = { 0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 12: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 26: 24, 27: 25, 30: 26, 31: 27, 32: 28, 33: 29, 34: 30, 35: 31, 36: 32, 37: 33, 38: 34, 39: 35, 40: 36, 41: 37, 42: 38, 43: 39, 45: 40, 46: 41, 47: 42, 48: 43, 49: 44, 50: 45, 51: 46, 52: 47, 53: 48, 54: 49, 55: 50, 56: 51, 57: 52, 58: 53, 59: 54, 60: 55, 61: 56, 62: 57, 63: 58, 64: 59, 66: 60, 69: 61, 71: 62, 72: 63, 73: 64, 74: 65, 75: 66, 76: 67, 77: 68, 78: 69, 79: 70, 80: 71, 81: 72, 83: 73, 84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79, 91: 80, 92: 81, 93: 82, 94: 83, 95: 84, 96: 85, 97: 86, 98: 87, 99: 88, 100: 89, 101: 90, 102: 91, 103: 92, 104: 93, 105: 94, 106: 95, 107: 96, 108: 97, 109: 98, 110: 99, 111: 100, 112: 101, 113: 102, 114: 103, 115: 104, 116: 105, 117: 106, 118: 107, 119: 108, 120: 109, 121: 110, 122: 111, 123: 112, 124: 113, 125: 114, 126: 115, 127: 116, 128: 117, 129: 118, 130: 119, 131: 120, 132: 121, 133: 122, 134: 123, 135: 124, 136: 125, 137: 126, 138: 127, 139: 128, 140: 129, 141: 130, 142: 131, 143: 132, 144: 133, 145: 134, 146: 135, 147: 136, 148: 137, 149: 138, 150: 139, 151: 140, 152: 141, 153: 142, 154: 143, 155: 144, 156: 145, 157: 146, 158: 147, 159: 148, 160: 149, 161: 150, 162: 151, 163: 152, 164: 153, 165: 154, 166: 155, 167: 156, 168: 157, 169: 158, 170: 159, 171: 160, 172: 161, 173: 162, 174: 163, 175: 164, 176: 165, 177: 166, 178: 167, 179: 168, 180: 169, 181: 170, 255: 255 } # set to background for k, v in clsID_to_trID.items(): clsID_to_trID[k] = v + 1 if k > 90: clsID_to_trID[k] = 0 def convert_to_trainID(maskpath, out_mask_dir, is_train): mask = np.array(Image.open(maskpath)) mask_copy = mask.copy() for clsID, trID in clsID_to_trID.items(): mask_copy[mask == clsID] = trID seg_filename = osp.join( out_mask_dir, 'train2017', osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png') if is_train else osp.join( out_mask_dir, 'val2017', osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png') Image.fromarray(mask_copy).save(seg_filename, 'PNG') def parse_args(): parser = argparse.ArgumentParser( description=\ 'Convert COCO Stuff 164k annotations to COCO Objects') # noqa parser.add_argument('coco_path', help='coco stuff path') parser.add_argument('-o', '--out_dir', help='output path') parser.add_argument( '--nproc', default=16, type=int, help='number of process') args = parser.parse_args() return args def main(): args = parse_args() coco_path = args.coco_path nproc = args.nproc out_dir = args.out_dir or coco_path out_img_dir = osp.join(out_dir, 'images') out_mask_dir = osp.join(out_dir, 'annotations') mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017')) mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017')) if out_dir != coco_path: shutil.copytree(osp.join(coco_path, 'images'), out_img_dir) train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png')) train_list = [file for file in train_list if 'TrainIds' not in file] test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png')) test_list = [file for file in test_list if 'TrainIds' not in file] assert (len(train_list) + len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( len(train_list), len(test_list)) if args.nproc > 1: mmcv.track_parallel_progress( partial( convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), train_list, nproc=nproc) mmcv.track_parallel_progress( partial( convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), test_list, nproc=nproc) else: mmcv.track_progress( partial( convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), train_list) mmcv.track_progress( partial( convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), test_list) print('Done!') if __name__ == '__main__': main()