build.py

import torch

from .lr_scheduler import LRSchedulerWithWarmup


def build_optimizer(args, model):
    params = []

    print(f"Using {args.lr_factor} times learning rate for random init modules")
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = args.lr
        weight_decay = args.weight_decay
        if "cross" in key:
            # use a larger learning rate for the randomly initialized cross-modal module
            lr = args.lr * args.lr_factor  # default 5.0
        if "bias" in key:
            lr = args.lr * args.bias_lr_factor
            weight_decay = args.weight_decay_bias
        if "classifier" in key or "mlm_head" in key:
            lr = args.lr * args.lr_factor
        # one parameter group per tensor so each keeps its own lr and weight decay
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum)
    elif args.optimizer == "Adam":
        optimizer = torch.optim.Adam(
            params,
            lr=args.lr,
            betas=(args.alpha, args.beta),
            eps=1e-3,
        )
    elif args.optimizer == "AdamW":
        optimizer = torch.optim.AdamW(
            params,
            lr=args.lr,
            betas=(args.alpha, args.beta),
            eps=1e-8,
        )
    else:
        raise NotImplementedError(f"Unsupported optimizer: {args.optimizer}")

    return optimizer
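A minimal usage sketch for build_optimizer follows. The Namespace is hypothetical: it only sets the attributes the function reads, and the values (as well as the stand-in model) are illustrative assumptions, not this repository's defaults.

# Hypothetical usage sketch; values are illustrative, not the project's defaults.
from argparse import Namespace

import torch.nn as nn

args = Namespace(
    optimizer="AdamW",
    lr=1e-5,
    lr_factor=5.0,          # extra multiplier for randomly initialized modules
    bias_lr_factor=2.0,
    weight_decay=4e-5,
    weight_decay_bias=0.0,
    alpha=0.9,              # first Adam/AdamW beta
    beta=0.999,             # second Adam/AdamW beta
    momentum=0.9,           # only read when optimizer == "SGD"
)

model = nn.Sequential(nn.Linear(8, 8))   # stand-in for the real model
optimizer = build_optimizer(args, model)
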
def build_lr_scheduler(args, optimizer):
    return LRSchedulerWithWarmup(
        optimizer,
        milestones=args.milestones,
        gamma=args.gamma,
        warmup_factor=args.warmup_factor,
        warmup_epochs=args.warmup_epochs,
        warmup_method=args.warmup_method,
        total_epochs=args.num_epoch,
        mode=args.lrscheduler,
        target_lr=args.target_lr,
        power=args.power,
    )
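
Continuing the sketch above, the scheduler can be built on top of the optimizer. The attribute values below are again illustrative assumptions (the valid mode strings depend on LRSchedulerWithWarmup), and the per-epoch step() call assumes the usual epoch-wise scheduler convention.

# Hypothetical scheduler setup; values and mode name are assumptions.
args.milestones = (20, 50)
args.gamma = 0.1
args.warmup_factor = 0.1
args.warmup_epochs = 5
args.warmup_method = "linear"
args.num_epoch = 60
args.lrscheduler = "cosine"   # must be a mode supported by LRSchedulerWithWarmup
args.target_lr = 0.0
args.power = 0.9

scheduler = build_lr_scheduler(args, optimizer)

for epoch in range(args.num_epoch):
    # ... run one training epoch ...
    scheduler.step()          # assumed to be stepped once per epoch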