# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
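"""Convert the YFCC14M image zip archives and caption TSV into WebDataset
tar shards.

Example invocation (the paths below are illustrative placeholders, not part
of the original source):

    python convert_yfcc14m.py \
        --root /path/to/yfcc14m/zips \
        --info /path/to/yfcc14m_captions.tsv \
        --shards /path/to/output/shards
"""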
import argparse
import json
import os
import os.path as osp
import random
import sys
import zipfile

import mmcv
import numpy as np
import pandas as pd
import webdataset as wds
from tqdm import tqdm


def write_dataset(args):
    df = pd.read_csv(
        args.info, sep='\t', index_col='file', dtype=str, lineterminator='\n')
    print(f'Loaded dataframe: \n{df}')
    print(f'Length: \n{len(df)}')

    # This is the output pattern under which we write shards: ShardWriter
    # starts a new tar whenever the current one reaches maxsize bytes or
    # maxcount samples.
    pattern = os.path.join(args.shards, 'yfcc14m-%06d.tar')
    with wds.ShardWriter(
            pattern, maxsize=int(args.maxsize),
            maxcount=int(args.maxcount)) as sink:
        sink.verbose = 0
        all_keys = set()
        skipped = 0
        zip_files = list(mmcv.scandir(args.root, suffix='zip'))
        for idx, file in tqdm(
                enumerate(zip_files), desc='total', total=len(zip_files)):
            with zipfile.ZipFile(osp.join(args.root, file), 'r') as zfile:
                filename_list = zfile.namelist()
                for filename in tqdm(
                        filename_list, position=1, desc=f'{file}', leave=None):
                    image = zfile.read(filename)
                    # ZipFile.read() returns bytes, never None, so test for
                    # empty payloads (e.g. directory entries) instead.
                    if not image:
                        skipped += 1
                        tqdm.write(f'Skipping {filename}, {skipped}/{len(df)}')
                        continue
                    fname = filename.replace('data/images/', '')
                    # Construct a unique key from the filename.
                    key = os.path.splitext(os.path.basename(fname))[0]
                    # Useful check: never emit the same key twice.
                    if key in all_keys:
                        tqdm.write(f'duplicate: {fname}')
                        continue
                    all_keys.add(key)
                    text = str(df.loc[fname]['caption'])
                    if len(text.split(' ')) < 2:
                        skipped += 1
                        tqdm.write(f'Text {text} too short')
                        tqdm.write(f'Skipping {fname}, {skipped}/{len(df)}')
                        continue
                    # Construct a sample.
                    xkey = key
                    sample = {'__key__': xkey, 'jpg': image, 'text': text}
                    # Write the sample to the sharded tar archives.
                    sink.write(sample)
    print(f'skipped: {skipped}/{len(df)}')
    print(f'total keys: {len(all_keys)}')
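

# A minimal sketch of reading the shards back with webdataset; the shard
# pattern and the decode/tuple choices below are illustrative assumptions,
# not part of the original script.
def example_read_shards(pattern):
    # e.g. pattern = '/path/to/output/shards/yfcc14m-{000000..000009}.tar'
    dataset = (
        wds.WebDataset(pattern)
        .decode('pil')               # decode the 'jpg' payload to a PIL image
        .to_tuple('jpg', 'text'))    # yield (image, caption) pairs
    for image, caption in dataset:
        print(image.size, caption[:50])
        break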


def parse_args():
    parser = argparse.ArgumentParser(
        description='Generate sharded WebDataset tars from the original '
        'YFCC14M data.')
    parser.add_argument('--maxsize', type=float, default=1e9)
    parser.add_argument('--maxcount', type=float, default=100000)
    parser.add_argument('--shards', help='directory where shards are written')
    parser.add_argument('--root', help='data root path')
    parser.add_argument('--info', help='path to the caption tsv')
    args = parser.parse_args()
    # Keep shards reasonably sized: at least ~10 MB, at most 1M samples each.
    assert args.maxsize > 10000000
    assert args.maxcount < 1000000
    return args


def main():
    args = parse_args()

    # Fix seeds for reproducibility.
    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not os.path.isdir(os.path.join(args.shards, '.')):
        print(
            f'{args.shards}: should be a writable destination directory for shards',
            file=sys.stderr)
        sys.exit(1)

    write_dataset(args=args)


if __name__ == '__main__':
    main()