1234567891011121314151617181920212223242526272829303132333435 |
- from prettytable import PrettyTable
- import os
- # os.environ['CUDA_VISIBLE_DEVICES'] = '3'
- import torch
- import numpy as np
- import time
- import os.path as op
- from datasets import build_dataloader
- from processor.processor import do_inference
- from utils.checkpoint import Checkpointer
- from utils.logger import setup_logger
- from model import build_model
- from utils.metrics import Evaluator
- import argparse
- from utils.iotools import load_train_configs
if __name__ == '__main__':
    # Evaluate the trained IRRA checkpoint `best.pth` found in the run's
    # output directory on the test image/text loaders.
    parser = argparse.ArgumentParser(description="IRRA Test")
    parser.add_argument(
        "--config_file",
        default='logs/CUHK-PEDES/iira/configs.yaml',
        help="path to the training config YAML read by load_train_configs",
    )
    cli_args = parser.parse_args()

    # From here on, every setting (output_dir, dataset, model options) comes
    # from the loaded config file; the CLI only supplies its path.
    args = load_train_configs(cli_args.config_file)
    args.training = False
    logger = setup_logger('IRRA', save_dir=args.output_dir, if_train=args.training)
    logger.info(args)

    # Previously hard-coded to "cuda", which made the script fail on hosts
    # without a GPU. NOTE(review): the CPU path assumes Checkpointer.load and
    # do_inference follow the model's device rather than hard-coding CUDA —
    # confirm against utils/checkpoint.py and processor/processor.py.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        logger.info("CUDA is not available; running inference on CPU")

    test_img_loader, test_txt_loader, num_classes = build_dataloader(args)
    model = build_model(args, num_classes=num_classes)

    # Restore the weights saved as `best.pth` in the run's output directory
    # before moving the model to the target device.
    checkpointer = Checkpointer(model)
    checkpointer.load(f=op.join(args.output_dir, 'best.pth'))
    model.to(device)

    do_inference(model, test_img_loader, test_txt_loader)
|