# processor.py — IRRA training / inference loops
# Standard library
import logging
import time

# Third-party
import torch

# Project-local utilities (meters, evaluation, distributed helpers)
from utils.meter import AverageMeter
from utils.metrics import Evaluator
from utils.comm import get_rank, synchronize

# TensorBoard logging and table pretty-printing
from torch.utils.tensorboard import SummaryWriter
from prettytable import PrettyTable  # NOTE(review): unused in this chunk; possibly used elsewhere in the file
def do_train(start_epoch, args, model, train_loader, evaluator, optimizer,
             scheduler, checkpointer):
    """Run the IRRA training loop from ``start_epoch`` through ``args.num_epoch``.

    Per epoch: resets loss/accuracy meters, iterates ``train_loader`` doing
    forward/backward/step, logs every ``args.log_period`` iterations, writes
    epoch-level scalars to TensorBoard, steps the LR scheduler, and (on rank 0,
    every ``args.eval_period`` epochs) evaluates and checkpoints the best model.

    Args:
        start_epoch: first epoch index (inclusive); epochs run to num_epoch inclusive.
        args: config namespace — reads log_period, eval_period, num_epoch,
            output_dir, distributed.
        model: returns a dict from ``model(batch)``; keys containing "loss" are
            summed into the total loss (see loop body).
        train_loader: yields dict batches of tensors; must expose batch_size.
        evaluator: object with ``.eval(model)`` returning a top-1 score.
        optimizer / scheduler / checkpointer: standard training collaborators.
    """
    log_period = args.log_period
    eval_period = args.eval_period
    device = "cuda"
    num_epoch = args.num_epoch
    # Metadata forwarded to checkpointer.save(**arguments).
    arguments = {}
    arguments["num_epoch"] = num_epoch
    arguments["iteration"] = 0

    logger = logging.getLogger("IRRA.train")
    logger.info('start training')

    # Running averages, reset at the start of each epoch.
    meters = {
        "loss": AverageMeter(),
        "sdm_loss": AverageMeter(),
        "itc_loss": AverageMeter(),
        "id_loss": AverageMeter(),
        "mlm_loss": AverageMeter(),
        "img_acc": AverageMeter(),
        "txt_acc": AverageMeter(),
        "mlm_acc": AverageMeter()
    }
    tb_writer = SummaryWriter(log_dir=args.output_dir)

    best_top1 = 0.0

    # train
    for epoch in range(start_epoch, num_epoch + 1):
        start_time = time.time()
        for meter in meters.values():
            meter.reset()
        model.train()

        for n_iter, batch in enumerate(train_loader):
            # Move every tensor in the batch dict to the GPU.
            batch = {k: v.to(device) for k, v in batch.items()}

            ret = model(batch)
            # Total loss = sum of every entry whose key contains "loss".
            total_loss = sum([v for k, v in ret.items() if "loss" in k])

            batch_size = batch['images'].shape[0]
            meters['loss'].update(total_loss.item(), batch_size)
            # Individual terms may be absent from ret (loss components are
            # config-dependent); default to 0 so the meter stays at 0 and the
            # `avg > 0` checks below skip them in logs/TensorBoard.
            meters['sdm_loss'].update(ret.get('sdm_loss', 0), batch_size)
            meters['itc_loss'].update(ret.get('itc_loss', 0), batch_size)
            meters['id_loss'].update(ret.get('id_loss', 0), batch_size)
            meters['mlm_loss'].update(ret.get('mlm_loss', 0), batch_size)
            meters['img_acc'].update(ret.get('img_acc', 0), batch_size)
            meters['txt_acc'].update(ret.get('txt_acc', 0), batch_size)
            # NOTE(review): mlm_acc is averaged with weight 1, unlike the
            # batch_size weighting above — presumably intentional; confirm.
            meters['mlm_acc'].update(ret.get('mlm_acc', 0), 1)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            # Barrier across distributed ranks after each step.
            synchronize()

            if (n_iter + 1) % log_period == 0:
                info_str = f"Epoch[{epoch}] Iteration[{n_iter + 1}/{len(train_loader)}]"
                # log loss and acc info
                for k, v in meters.items():
                    if v.avg > 0:
                        info_str += f", {k}: {v.avg:.4f}"
                info_str += f", Base Lr: {scheduler.get_lr()[0]:.2e}"
                logger.info(info_str)

        # Epoch-level TensorBoard scalars. NOTE(review): `ret` here is the dict
        # from the LAST batch of the epoch (temperature taken from it), and
        # this raises if train_loader was empty.
        tb_writer.add_scalar('lr', scheduler.get_lr()[0], epoch)
        tb_writer.add_scalar('temperature', ret['temperature'], epoch)
        for k, v in meters.items():
            if v.avg > 0:
                tb_writer.add_scalar(k, v.avg, epoch)

        # Scheduler advances once per epoch (after the inner iteration loop).
        scheduler.step()

        if get_rank() == 0:
            end_time = time.time()
            time_per_batch = (end_time - start_time) / (n_iter + 1)
            logger.info(
                "Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]"
                .format(epoch, time_per_batch,
                        train_loader.batch_size / time_per_batch))

        # Evaluation + best-checkpoint saving happen only on rank 0.
        if epoch % eval_period == 0:
            if get_rank() == 0:
                logger.info("Validation Results - Epoch: {}".format(epoch))
                if args.distributed:
                    # Unwrap DistributedDataParallel before evaluating.
                    top1 = evaluator.eval(model.module.eval())
                else:
                    top1 = evaluator.eval(model.eval())

                torch.cuda.empty_cache()
                if best_top1 < top1:
                    best_top1 = top1
                    arguments["epoch"] = epoch
                    checkpointer.save("best", **arguments)

    if get_rank() == 0:
        # NOTE(review): arguments["epoch"] is only set when some eval improved
        # best_top1 — this line raises KeyError if that never happened; confirm
        # intended behavior.
        logger.info(f"best R1: {best_top1} at epoch {arguments['epoch']}")
  90. def do_inference(model, test_img_loader, test_txt_loader):
  91. logger = logging.getLogger("IRRA.test")
  92. logger.info("Enter inferencing")
  93. evaluator = Evaluator(test_img_loader, test_txt_loader)
  94. top1 = evaluator.eval(model.eval())