process_redcaps.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # -------------------------------------------------------------------------
  13. import argparse
  14. import json
  15. import os
  16. import os.path as osp
  17. import random
  18. import pandas as pd
  19. import pyarrow as pa
  20. import pyarrow.parquet as pq
  21. import tqdm
  22. def get_args_parser():
  23. parser = argparse.ArgumentParser()
  24. parser.add_argument(
  25. 'input', type=str, help='path to redcaps annotations directory')
  26. parser.add_argument(
  27. 'output', type=str, help='output annotations file path')
  28. parser.add_argument(
  29. '--num-split', type=int, help='number of splits to make')
  30. return parser
  31. def main(args):
  32. annos = []
  33. for fname in tqdm.tqdm(os.listdir(args.input), desc='merging json files'):
  34. if fname.endswith('json'):
  35. with open(os.path.join(args.input, fname)) as f:
  36. a = json.load(f)
  37. for d in a['annotations']:
  38. cur_d = {'URL': d['url'], 'TEXT': d['caption']}
  39. annos.append(cur_d)
  40. random.seed(42)
  41. random.shuffle(annos)
  42. if args.num_split is None:
  43. df = pd.DataFrame(annos)
  44. print(df.head())
  45. print(f'saving {len(df)} annotations to {args.output}')
  46. table = pa.Table.from_pandas(df)
  47. os.makedirs(osp.dirname(args.output), exist_ok=True)
  48. pq.write_table(table, args.output)
  49. else:
  50. for i in range(args.num_split):
  51. df = pd.DataFrame(annos[i::args.num_split])
  52. print(df.head())
  53. output = osp.splitext(
  54. args.output)[0] + f'_part{i}{osp.splitext(args.output)[1]}'
  55. print(f'saving {len(df)} annotations to {output}')
  56. table = pa.Table.from_pandas(df)
  57. os.makedirs(osp.dirname(output), exist_ok=True)
  58. pq.write_table(table, output)
  59. if __name__ == '__main__':
  60. parser = get_args_parser()
  61. args = parser.parse_args()
  62. main(args)