test.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435
  1. from prettytable import PrettyTable
  2. import os
  3. # os.environ['CUDA_VISIBLE_DEVICES'] = '3'
  4. import torch
  5. import numpy as np
  6. import time
  7. import os.path as op
  8. from datasets import build_dataloader
  9. from processor.processor import do_inference
  10. from utils.checkpoint import Checkpointer
  11. from utils.logger import setup_logger
  12. from model import build_model
  13. from utils.metrics import Evaluator
  14. import argparse
  15. from utils.iotools import load_train_configs
  16. if __name__ == '__main__':
  17. parser = argparse.ArgumentParser(description="IRRA Test")
  18. parser.add_argument("--config_file", default='logs/CUHK-PEDES/iira/configs.yaml')
  19. args = parser.parse_args()
  20. args = load_train_configs(args.config_file)
  21. args.training = False
  22. logger = setup_logger('IRRA', save_dir=args.output_dir, if_train=args.training)
  23. logger.info(args)
  24. device = "cuda"
  25. test_img_loader, test_txt_loader, num_classes = build_dataloader(args)
  26. model = build_model(args, num_classes=num_classes)
  27. checkpointer = Checkpointer(model)
  28. checkpointer.load(f=op.join(args.output_dir, 'best.pth'))
  29. model.to(device)
  30. do_inference(model, test_img_loader, test_txt_loader)