process_redcaps.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
  3. #
  4. # This work is made available under the Nvidia Source Code License.
  5. # To view a copy of this license, visit
  6. # https://github.com/NVlabs/GroupViT/blob/main/LICENSE
  7. #
  8. # Written by Jiarui Xu
  9. # -------------------------------------------------------------------------
  10. import argparse
  11. import json
  12. import os
  13. import os.path as osp
  14. import random
  15. import pandas as pd
  16. import pyarrow as pa
  17. import pyarrow.parquet as pq
  18. import tqdm
  19. def get_args_parser():
  20. parser = argparse.ArgumentParser()
  21. parser.add_argument(
  22. 'input', type=str, help='path to redcaps annotations directory')
  23. parser.add_argument(
  24. 'output', type=str, help='output annotations file path')
  25. parser.add_argument(
  26. '--num-split', type=int, help='number of splits to make')
  27. return parser
  28. def main(args):
  29. annos = []
  30. for fname in tqdm.tqdm(os.listdir(args.input), desc='merging json files'):
  31. if fname.endswith('json'):
  32. with open(os.path.join(args.input, fname)) as f:
  33. a = json.load(f)
  34. for d in a['annotations']:
  35. cur_d = {'URL': d['url'], 'TEXT': d['caption']}
  36. annos.append(cur_d)
  37. random.seed(42)
  38. random.shuffle(annos)
  39. if args.num_split is None:
  40. df = pd.DataFrame(annos)
  41. print(df.head())
  42. print(f'saving {len(df)} annotations to {args.output}')
  43. table = pa.Table.from_pandas(df)
  44. os.makedirs(osp.dirname(args.output), exist_ok=True)
  45. pq.write_table(table, args.output)
  46. else:
  47. for i in range(args.num_split):
  48. df = pd.DataFrame(annos[i::args.num_split])
  49. print(df.head())
  50. output = osp.splitext(
  51. args.output)[0] + f'_part{i}{osp.splitext(args.output)[1]}'
  52. print(f'saving {len(df)} annotations to {output}')
  53. table = pa.Table.from_pandas(df)
  54. os.makedirs(osp.dirname(output), exist_ok=True)
  55. pq.write_table(table, output)
  56. if __name__ == '__main__':
  57. parser = get_args_parser()
  58. args = parser.parse_args()
  59. main(args)