convert_yfcc14m.py

# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
import argparse
import os
import os.path as osp
import random
import sys
import zipfile

import mmcv
import numpy as np
import pandas as pd
import webdataset as wds
from tqdm import tqdm
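
# Expected --info TSV layout (inferred from the lookups below): one row per
# image, indexed by a 'file' column holding the image filename inside the
# zips, with a 'caption' column holding the paired text. Any further columns
# in the YFCC14M metadata file are ignored here.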


def write_dataset(args):
    df = pd.read_csv(
        args.info, sep='\t', index_col='file', dtype=str, lineterminator='\n')
    print(f'Loaded dataframe: \n{df}')
    print(f'Length: {len(df)}')
    # This is the output pattern under which we write shards.
    pattern = os.path.join(args.shards, 'yfcc14m-%06d.tar')
    with wds.ShardWriter(
            pattern, maxsize=int(args.maxsize),
            maxcount=int(args.maxcount)) as sink:
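        # ShardWriter rolls over to a new tar (filling in the %06d shard
        # index) whenever the current shard reaches maxsize bytes or
        # maxcount samples.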
        sink.verbose = 0
        all_keys = set()
        skipped = 0
        zip_files = list(mmcv.scandir(args.root, suffix='zip'))
        for file in tqdm(zip_files, desc='total'):
            with zipfile.ZipFile(osp.join(args.root, file), 'r') as zfile:
                filename_list = zfile.namelist()
                for filename in tqdm(
                        filename_list, position=1, desc=f'{file}', leave=None):
                    image = zfile.read(filename)
                    # zfile.read() returns bytes, never None; skip empty
                    # payloads such as directory entries instead.
                    if not image:
                        skipped += 1
                        tqdm.write(f'Skipping {filename}, {skipped}/{len(df)}')
                        continue
                    fname = filename.replace('data/images/', '')
                    # Construct a unique key from the filename.
                    key = os.path.splitext(os.path.basename(fname))[0]
                    # Skip duplicate keys.
                    if key in all_keys:
                        tqdm.write(f'duplicate: {fname}')
                        continue
                    all_keys.add(key)
                    text = str(df.loc[fname]['caption'])
                    # Require at least a two-word caption.
                    if len(text.split(' ')) < 2:
                        skipped += 1
                        tqdm.write(f'Text {text} too short')
                        tqdm.write(f'Skipping {fname}, {skipped}/{len(df)}')
                        continue
                    # Construct a sample and write it to the sharded tars.
                    sample = {'__key__': key, 'jpg': image, 'text': text}
                    sink.write(sample)
    print(f'skipped: {skipped}/{len(df)}')
    print(f'total keys: {len(all_keys)}')
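

# A minimal read-back sketch, not part of the original script: it streams a
# few samples out of the finished shards to sanity-check them, using the
# webdataset package imported above. The helper name, the 'shard_spec'
# argument (e.g. a brace-expanded spec such as
# 'yfcc14m-{000000..000009}.tar'), and the sample count are illustrative
# assumptions; raw iteration yields the tar members as bytes.
def preview_shards(shard_spec, num_samples=3):
    """Print key, image byte size and caption for the first few samples."""
    dataset = wds.WebDataset(shard_spec)
    for i, sample in enumerate(dataset):
        if i >= num_samples:
            break
        caption = sample['text'].decode('utf-8')
        print(sample['__key__'], len(sample['jpg']), caption[:60])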


def parse_args():
    parser = argparse.ArgumentParser(
        description='Generate a sharded webdataset from the original '
        'YFCC14M data.')
    parser.add_argument('--maxsize', type=float, default=1e9)
    parser.add_argument('--maxcount', type=float, default=100000)
    parser.add_argument('--shards', help='directory where shards are written')
    parser.add_argument('--root', help='data root path')
    parser.add_argument('--info', help='tsv path')
    args = parser.parse_args()
    # Keep shards sane: at least ~10 MB and fewer than 1M samples each.
    assert args.maxsize > 10000000
    assert args.maxcount < 1000000
    return args
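
# Example invocation (paths are hypothetical placeholders):
#   python convert_yfcc14m.py \
#       --root /path/to/yfcc14m/zips \
#       --shards /path/to/output/shards \
#       --info /path/to/yfcc14m.tsv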


def main():
    args = parse_args()
    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    # Note: setting PYTHONHASHSEED here only affects child processes; the
    # current interpreter's hash seed was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    if not os.path.isdir(os.path.join(args.shards, '.')):
        print(
            f'{args.shards}: should be a writable destination directory for shards',
            file=sys.stderr)
        sys.exit(1)
    write_dataset(args=args)


if __name__ == '__main__':
    main()