# create_subset.py
# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
  10. import os
  11. import os.path as osp
  12. import argparse
  13. import pandas as pd
  14. import sqlite3
  15. import pandas as pd
  16. import os.path as osp
  17. from urllib.parse import unquote
  18. import re
  19. from datadings.tools import locate_files
  20. from yfcc100m.vars import FILES
  21. from yfcc100m.convert_metadata import download_db
  22. from pandarallel import pandarallel
  23. pandarallel.initialize(progress_bar=True)
  24. def key2path(key):
  25. img_path = osp.join(key[0:3], key[3:6], key + '.jpg')
  26. return img_path
  27. def clean_caption(line):
  28. line = unquote(str(line))
  29. line = remove_html_tags(line)
  30. return line.replace('\n', ' ').replace('+', ' ')
  31. def remove_html_tags(text):
  32. """Remove html tags from a string"""
  33. clean = re.compile('<.*?>')
  34. return re.sub(clean, '', text)
  35. def parse_args():
  36. parser = argparse.ArgumentParser(description='Create YFCC subset sql db and tsv')
  37. parser.add_argument('--input-dir', help='input sql db file directory')
  38. parser.add_argument('--output-dir', help='output tsv directory')
  39. parser.add_argument(
  40. '--subset', help='subset of data to use', default='yfcc100m_subset_data.tsv')
  41. args = parser.parse_args()
  42. return args
  43. def main():
  44. args = parse_args()
  45. files = locate_files(FILES, args.input_dir)
  46. # download DB file with AWS tools
  47. download_db(files)
  48. fullset_name = 'yfcc100m_dataset'
  49. subset_name = 'yfcc14m_dataset'
  50. conn = sqlite3.connect(osp.join(args.input_dir, 'yfcc100m_dataset.sql'))
  51. # get column names
  52. # some settings that hopefully speed up the queries
  53. # conn.execute(f'PRAGMA query_only = YES')
  54. conn.execute(f'PRAGMA journal_mode = OFF')
  55. conn.execute(f'PRAGMA locking_mode = EXCLUSIVE')
  56. conn.execute(f'PRAGMA page_size = 4096')
  57. conn.execute(f'PRAGMA mmap_size = {4*1024*1024}')
  58. conn.execute(f'PRAGMA cache_size = 10000')
  59. print('reading subset data')
  60. subset_df = pd.read_csv(args.subset, sep='\t', usecols=[1, 2], names=['photoid', 'photo_hash'], index_col='photoid')
  61. subset_df.to_sql(subset_name, con=conn, if_exists='replace')
  62. print('overwriting with subset')
  63. select_query = f'select {fullset_name}.*, {subset_name}.photo_hash from {fullset_name} inner join {subset_name} on {fullset_name}.photoid = {subset_name}.photoid'
  64. new_name = 'yfcc100m_dataset_new'
  65. print('creating new table')
  66. conn.execute(f'drop table if exists {new_name}')
  67. conn.execute(' '.join([f'create table {new_name} as ', select_query]))
  68. print(f'droping {fullset_name}')
  69. conn.execute(f'drop table if exists {fullset_name}')
  70. print(f'droping {subset_name}')
  71. conn.execute(f'drop table if exists {subset_name}')
  72. print(f'renaming {new_name} to {fullset_name}')
  73. conn.execute(f'alter table {new_name} rename to {fullset_name}')
  74. print('vacuuming db')
  75. conn.execute('vacuum')
  76. print(f'Loading dataframe from SQL')
  77. anno_df = pd.read_sql(f'select * from {fullset_name}', con=conn)
  78. print(f'Loaded dataframe from SQL: \n{anno_df.head()}')
  79. print(f'Length: \n{len(anno_df)}')
  80. print(f'generating filepath')
  81. anno_df['file'] = anno_df['photo_hash'].parallel_map(key2path)
  82. anno_df['caption'] = anno_df['description'].parallel_map(clean_caption)
  83. anno_df = anno_df[['file', 'caption']]
  84. print(f'Generated dataframe: \n{anno_df.head()}')
  85. print('saving subset as tsv')
  86. os.makedirs(args.output_dir, exist_ok=True)
  87. anno_df.to_csv(osp.join(args.output_dir, 'yfcc14m_dataset.tsv'), sep='\t', index=False)
  88. conn.close()
# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()