classify_val_images.py 1.2 KB

123456789101112131415161718192021222324252627282930313233
  1. import os
  2. import shutil
  3. import torch
  4. # 设置根目录路径
  5. root_directory = "/mnt/vos-s9gjtkm2/reid/dataset/imagenet"
  6. val_directory = os.path.join(root_directory, "val")
  7. val_images_directory = val_directory
  8. # 读取验证集的 ground truth 文件
  9. ground_truth_file = os.path.join(root_directory, "ILSVRC2012_devkit_t12", "data", "ILSVRC2012_validation_ground_truth.txt")
  10. with open(ground_truth_file, "r") as f:
  11. val_labels = f.readlines()
  12. # 读取 meta.bin 文件以获取类别信息
  13. meta_file = os.path.join(root_directory, "meta.bin")
  14. wnid_to_classes, val_wnids = torch.load(meta_file)
  15. # 创建类别目录
  16. for wnid in wnid_to_classes.keys():
  17. class_directory = os.path.join(val_directory, wnid)
  18. if not os.path.exists(class_directory):
  19. os.makedirs(class_directory)
  20. # 将图片按类别分类
  21. for i, label in enumerate(val_labels):
  22. label = label.strip()
  23. wnid = val_wnids[i]
  24. src_file = os.path.join(val_images_directory, f"ILSVRC2012_val_{i + 1:08d}.JPEG")
  25. dest_file = os.path.join(val_directory, wnid, f"ILSVRC2012_val_{i + 1:08d}.JPEG")
  26. shutil.move(src_file, dest_file)
  27. print("Validation images have been classified by category.")