import torch

from .lr_scheduler import LRSchedulerWithWarmup


def build_optimizer(args, model):
    params = []

    print(f"Using {args.lr_factor} times learning rate for randomly initialized modules")

    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = args.lr
        weight_decay = args.weight_decay
        if "cross" in key:
            # use a larger learning rate for the randomly initialized cross-modal module
            lr = args.lr * args.lr_factor  # default 5.0
        if "bias" in key:
            lr = args.lr * args.bias_lr_factor
            weight_decay = args.weight_decay_bias
        if "classifier" in key or "mlm_head" in key:
            lr = args.lr * args.lr_factor

        # one parameter group per tensor, each with its own lr / weight decay
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum)
    elif args.optimizer == "Adam":
        optimizer = torch.optim.Adam(
            params,
            lr=args.lr,
            betas=(args.alpha, args.beta),
            eps=1e-3,
        )
    elif args.optimizer == "AdamW":
        optimizer = torch.optim.AdamW(
            params,
            lr=args.lr,
            betas=(args.alpha, args.beta),
            eps=1e-8,
        )
    else:
        raise NotImplementedError(f"Unsupported optimizer: {args.optimizer}")

    return optimizer


def build_lr_scheduler(args, optimizer):
    return LRSchedulerWithWarmup(
        optimizer,
        milestones=args.milestones,
        gamma=args.gamma,
        warmup_factor=args.warmup_factor,
        warmup_epochs=args.warmup_epochs,
        warmup_method=args.warmup_method,
        total_epochs=args.num_epoch,
        mode=args.lrscheduler,
        target_lr=args.target_lr,
        power=args.power,
    )
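
A minimal usage sketch for build_optimizer (my addition, not from the project): the SimpleNamespace fields simply mirror the attributes read above, and the values and the stand-in model are illustrative assumptions, not the project's real configuration.

from types import SimpleNamespace

import torch.nn as nn

# Illustrative settings only; the real project builds `args` from its own config.
args = SimpleNamespace(
    lr=1e-5, lr_factor=5.0, bias_lr_factor=2.0,
    weight_decay=4e-5, weight_decay_bias=0.0,
    optimizer="AdamW", alpha=0.9, beta=0.999, momentum=0.9,
)

# Any model works; a tiny stand-in is enough to exercise the parameter groups.
model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))

optimizer = build_optimizer(args, model)
print(len(optimizer.param_groups))  # one group per trainable parameter tensor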