create_subset.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # -------------------------------------------------------------------------
  13. import os
  14. import os.path as osp
  15. import argparse
  16. import pandas as pd
  17. import sqlite3
  18. import pandas as pd
  19. import os.path as osp
  20. from urllib.parse import unquote
  21. import re
  22. from datadings.tools import locate_files
  23. from yfcc100m.vars import FILES
  24. from yfcc100m.convert_metadata import download_db
  25. from pandarallel import pandarallel
  26. pandarallel.initialize(progress_bar=True)
  27. def key2path(key):
  28. img_path = osp.join(key[0:3], key[3:6], key + '.jpg')
  29. return img_path
  30. def clean_caption(line):
  31. line = unquote(str(line))
  32. line = remove_html_tags(line)
  33. return line.replace('\n', ' ').replace('+', ' ')
  34. def remove_html_tags(text):
  35. """Remove html tags from a string"""
  36. clean = re.compile('<.*?>')
  37. return re.sub(clean, '', text)
  38. def parse_args():
  39. parser = argparse.ArgumentParser(description='Create YFCC subset sql db and tsv')
  40. parser.add_argument('--input-dir', help='input sql db file directory')
  41. parser.add_argument('--output-dir', help='output tsv directory')
  42. parser.add_argument(
  43. '--subset', help='subset of data to use', default='yfcc100m_subset_data.tsv')
  44. args = parser.parse_args()
  45. return args
  46. def main():
  47. args = parse_args()
  48. files = locate_files(FILES, args.input_dir)
  49. # download DB file with AWS tools
  50. download_db(files)
  51. fullset_name = 'yfcc100m_dataset'
  52. subset_name = 'yfcc14m_dataset'
  53. conn = sqlite3.connect(osp.join(args.input_dir, 'yfcc100m_dataset.sql'))
  54. # get column names
  55. # some settings that hopefully speed up the queries
  56. # conn.execute(f'PRAGMA query_only = YES')
  57. conn.execute(f'PRAGMA journal_mode = OFF')
  58. conn.execute(f'PRAGMA locking_mode = EXCLUSIVE')
  59. conn.execute(f'PRAGMA page_size = 4096')
  60. conn.execute(f'PRAGMA mmap_size = {4*1024*1024}')
  61. conn.execute(f'PRAGMA cache_size = 10000')
  62. print('reading subset data')
  63. subset_df = pd.read_csv(args.subset, sep='\t', usecols=[1, 2], names=['photoid', 'photo_hash'], index_col='photoid')
  64. subset_df.to_sql(subset_name, con=conn, if_exists='replace')
  65. print('overwriting with subset')
  66. select_query = f'select {fullset_name}.*, {subset_name}.photo_hash from {fullset_name} inner join {subset_name} on {fullset_name}.photoid = {subset_name}.photoid'
  67. new_name = 'yfcc100m_dataset_new'
  68. print('creating new table')
  69. conn.execute(f'drop table if exists {new_name}')
  70. conn.execute(' '.join([f'create table {new_name} as ', select_query]))
  71. print(f'droping {fullset_name}')
  72. conn.execute(f'drop table if exists {fullset_name}')
  73. print(f'droping {subset_name}')
  74. conn.execute(f'drop table if exists {subset_name}')
  75. print(f'renaming {new_name} to {fullset_name}')
  76. conn.execute(f'alter table {new_name} rename to {fullset_name}')
  77. print('vacuuming db')
  78. conn.execute('vacuum')
  79. print(f'Loading dataframe from SQL')
  80. anno_df = pd.read_sql(f'select * from {fullset_name}', con=conn)
  81. print(f'Loaded dataframe from SQL: \n{anno_df.head()}')
  82. print(f'Length: \n{len(anno_df)}')
  83. print(f'generating filepath')
  84. anno_df['file'] = anno_df['photo_hash'].parallel_map(key2path)
  85. anno_df['caption'] = anno_df['description'].parallel_map(clean_caption)
  86. anno_df = anno_df[['file', 'caption']]
  87. print(f'Generated dataframe: \n{anno_df.head()}')
  88. print('saving subset as tsv')
  89. os.makedirs(args.output_dir, exist_ok=True)
  90. anno_df.to_csv(osp.join(args.output_dir, 'yfcc14m_dataset.tsv'), sep='\t', index=False)
  91. conn.close()
  92. if __name__ == '__main__':
  93. main()