data_process_cc12m.py

import json
import numpy as np
import nltk
import argparse
import os
import pandas as pd
from ipdb import set_trace
import subprocess
import random
from multiprocessing import Pool

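# Pipeline overview (see the __main__ block below): --mode filter tokenizes each caption
# and keeps rows that mention an entity from TOP_CLASSES_1, --mode merge concatenates the
# per-process CSVs, --mode makepair samples cross-image pairs that share an entity, and
# --mode json2csv converts a JSON-lines metafile to CSV.
# Note: nltk.word_tokenize needs the NLTK 'punkt' tokenizer data; if it is missing,
# run nltk.download('punkt') once beforehand.
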
### You can modify this as you want ###
TOP_CLASSES_1 = [
    'people', 'man', 'men', 'woman', 'women', 'girl', 'boy', 'lady', 'kid', 'child', 'children', 'baby', 'student', 'bride', 'groom', 'couple', 'prince', 'princess',
    'car', 'bus', 'truck', 'motorcycle', 'train', 'bicycle', 'boat', 'aeroplane', 'airplane', 'motorbike', 'bike',
    'cup', 'bottle', 'bowl', 'knife', 'spoon', 'glass', 'fork',
    'chair', 'table', 'bench', 'clock', 'laptop', 'light', 'vase', 'plant', 'remote', 'microwave', 'toaster', 'oven', 'mouse', 'keyboard', 'sofa', 'monitor', 'desk', 'tv', 'TV', 'couch', 'flower', 'refrigerator',
    'house', 'building', 'hotel',
    'handbag', 'umbrella', 'book', 'backpack', 'phone', 'shirt', 'tie', 'suitcase', 'T-shirt', 'bag', 'box',
    'sink', 'bed', 'toilet',
    'cat', 'dog', 'horse', 'bird', 'cow', 'sheep', 'elephant', 'bear', 'zebra', 'giraffe',
    'ball', 'racket', 'skateboard', 'skis', 'snowboard', 'surfboard', 'kite',
    'pizza', 'cake', 'apple', 'banana', 'sandwich', 'orange', 'carrot', 'donut',
]

def judge_noun(word):
    if word in TOP_CLASSES_1:
        return 1
    return 0

def make_filter(infos):
    args, cur_index, subset_list = infos[0], infos[1], infos[2]
    new_dataframe = pd.DataFrame()
    print(f'Begin processing {cur_index}')
    for i, item in enumerate(subset_list.iterrows()):
        each_cap = item[1]['caption']
        all_words = nltk.word_tokenize(each_cap)
        valid_list = [judge_noun(word) for word in all_words]
        valid = sum(valid_list)
        if valid:
            valid_words = np.array(all_words)[np.argwhere(valid_list)][:, 0].tolist()
            valid_words = list(set(valid_words))  ## keep unique entities
            item[1]['entity'] = ','.join(valid_words)
            # DataFrame._append is pandas' private replacement for the removed
            # DataFrame.append; rows are accumulated one at a time here.
            new_dataframe = new_dataframe._append(item[1])
    print('Filtered {} out of {}'.format(len(new_dataframe), len(subset_list)))
    new_dataframe.to_csv(f'{args.dstdir}/cc12m_{cur_index}.csv', index=False)
    return

def filter_subset_with_entities(args):
    # all_captions = pd.read_csv(f'{args.srcdir}/cc12m.tsv', sep='\t')
    if args.srcdir.endswith('.tsv'):
        all_captions = pd.read_csv(args.srcdir, sep='\t')
    elif args.srcdir.endswith('.csv'):
        all_captions = pd.read_csv(args.srcdir)
    all_captions.columns = ['image_id', 'caption']
    total_len = len(all_captions)
    # Split the rows into equally sized chunks, one per worker process.
    all_ranges = np.linspace(0.0, 1.0, args.processors + 1)
    chunk_list = []
    for i in range(args.processors):
        begin = int(all_ranges[i] * total_len)
        ender = int(all_ranges[i + 1] * total_len)
        subset_list = all_captions[begin:ender]
        chunk_list.append([args, i, subset_list])
    print(f'Begin filtering with {args.processors} processes')
    pool = Pool(args.processors)
    pool.map(make_filter, tuple(chunk_list))
    pool.close()
    pool.join()

def merge_all_subset(args):
    all_files = os.listdir(args.dstdir)
    all_files = [f for f in all_files if f.endswith('.csv')]
    all_files = sorted(all_files)
    all_data = pd.DataFrame()
    for f in all_files:
        each_data = pd.read_csv(os.path.join(args.dstdir, f))
        all_data = all_data._append(each_data)
    if args.remove_subfiles:
        print('Removing sub-files')
        for f in all_files:
            cmd = f'rm -f {os.path.join(args.dstdir, f)}'
            print(cmd)
            os.system(cmd)
    all_data.to_csv(os.path.join(args.dstdir, 'cc12m_filtered_subset.csv'), index=False)

def construct_crossimage(args):
    '''
    For each image, randomly sample K (e.g. 10) images that share at least one entity with it.
    '''
    metafile = pd.read_csv(args.metafile)
    all_entites = metafile['entity'].tolist()
    # Build an inverted index: entity -> list of row indices whose captions contain it.
    entity_dict = {}
    for i, each_entity in enumerate(all_entites):
        each_entity = each_entity.split(',')
        # print(i, each_entity)
        for sub_entity in each_entity:
            if sub_entity not in entity_dict:
                entity_dict[sub_entity] = []
            entity_dict[sub_entity].append(i)
    print('Done calculating entity dict')
    for k, v in entity_dict.items():
        print(k, len(v))
    ### assign entity ###
    topK = 10
    all_pairs = []
    all_paired_entity = []
    print(f'Begin sampling {topK} pairs for each element')
    for i, each_entity in enumerate(all_entites):
        each_entity = each_entity.split(',')
        # Sample topK of this row's entities (with replacement), then pick a random
        # partner row for each sampled entity from the inverted index.
        sampled_entity = np.random.choice(each_entity, size=topK, replace=True)
        sampled_pair = [random.choice(entity_dict[x]) for x in sampled_entity]
        all_pairs.append(sampled_pair)
        all_paired_entity.append(sampled_entity.tolist())
    assert len(all_pairs) == len(all_entites) == len(all_paired_entity)
    metafile['pairindex'] = all_pairs
    metafile['pairentity'] = all_paired_entity
    metafile.to_csv(args.metafile.replace('.csv', '_pair.csv'), index=False)
    print('Done constructing pairs')

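# Note: 'pairindex' and 'pairentity' hold Python lists, so pandas writes them to the CSV
# as their string representation (e.g. "[3, 17, 42]"). A downstream loader would need to
# parse them back; an illustrative snippet (not part of this script):
#   import ast
#   pairs = pd.read_csv('cc12m_filtered_subset_pair.csv')
#   pairs['pairindex'] = pairs['pairindex'].apply(ast.literal_eval)
#   pairs['pairentity'] = pairs['pairentity'].apply(ast.literal_eval)
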
def convert_json_to_csv(args):
    # Expects a JSON-lines metafile (one JSON object per line).
    jsonfile = args.metafile
    df = pd.DataFrame()
    print('Start converting')
    with open(jsonfile) as f:
        lines = f.readlines()
    for line in lines:
        info = json.loads(line)
        df = df._append(pd.Series(info), ignore_index=True)
    outdir = args.metafile.replace('.json', '.csv')
    df.to_csv(outdir)
    print(f'Done converting to {outdir}')
    return

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--processors', type=int, default=8, help='number of processes for data filtering')
    parser.add_argument('--srcdir', type=str, default=None, help='path to the original cc12m metafile (.tsv or .csv)')
    parser.add_argument('--dstdir', type=str, default=None, help='target dir to save the filtered subset')
    parser.add_argument('--mode', type=str, default='filter', help='choices: [filter, merge, makepair, json2csv]')
    parser.add_argument('--metafile', type=str, default=None, help='the metafile used for constructing cross-image pairs')
    parser.add_argument('--remove_subfiles', action='store_true', help='whether to remove the generated sub-files')
    args = parser.parse_args()

    if args.mode == 'filter':
        assert args.srcdir is not None, 'Please specify the source path of the cc12m metafile'
        if args.dstdir is None:
            args.dstdir = f'{"/".join(args.srcdir.split("/")[:-1])}/subsets'
            print(f'Target dir not specified, use {args.dstdir}')
        os.makedirs(args.dstdir, exist_ok=True)
        filter_subset_with_entities(args)
    elif args.mode == 'merge':
        assert args.dstdir is not None, 'Please specify the target dir containing the filtered metafiles'
        merge_all_subset(args)
    elif args.mode == 'makepair':
        assert args.metafile is not None, 'Please specify the metafile for constructing the cross-image relation'
        construct_crossimage(args)
    elif args.mode == 'json2csv':
        assert args.metafile is not None, 'Please specify the metafile for converting'
        convert_json_to_csv(args)
    else:
        raise NotImplementedError
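
# Example usage (paths are illustrative):
#   python data_process_cc12m.py --mode filter --srcdir /path/to/cc12m.tsv --processors 8
#   python data_process_cc12m.py --mode merge --dstdir /path/to/subsets --remove_subfiles
#   python data_process_cc12m.py --mode makepair --metafile /path/to/subsets/cc12m_filtered_subset.csv
#   python data_process_cc12m.py --mode json2csv --metafile /path/to/metafile.json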