main_group_vit.py

# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
import argparse
import datetime
import os
import os.path as osp
import time
from collections import defaultdict

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
from datasets import build_loader, build_text_transform, imagenet_classes
from datasets.build import build_dataloader
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import collect_env, get_git_hash
from mmseg.apis import multi_gpu_test
from models import build_model
from omegaconf import OmegaConf, read_write
from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
from timm.utils import AverageMeter, accuracy
from tools.cfg2arg import cfg2arg
from utils import (auto_resume_helper, build_dataset_class_tokens, build_optimizer, build_scheduler, data2cuda,
                   get_config, get_grad_norm, get_logger, load_checkpoint, parse_losses, reduce_tensor,
                   save_checkpoint)
from utils.metrics import Evaluator

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None

def parse_args():
    parser = argparse.ArgumentParser('GroupViT training and evaluation script')
    parser.add_argument('--cfg', type=str, required=True, help='path to config file')
    parser.add_argument('--opts', help="Modify config options by adding 'KEY=VALUE' list. ", default=None, nargs='+')

    # easy config modification
    parser.add_argument('--batch-size', type=int, help='batch size for single GPU')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument(
        '--amp-opt-level',
        type=str,
        default='O1',
        choices=['O0', 'O1', 'O2'],
        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument(
        '--output', type=str, help='root of output folder, '
        'the full path is <output>/<model_name>/<tag>')
    parser.add_argument('--tag', type=str, help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--wandb', action='store_true', help='Use W&B to log experiments')
    parser.add_argument('--keep', type=int, help='Maximum checkpoint to keep')

    # distributed training
    parser.add_argument('--local_rank', type=int, required=True, help='local rank for DistributedDataParallel')

    args = parser.parse_args()

    return args

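# Top-level training/evaluation driver: sets up optional W&B logging, builds the
# image-text and segmentation dataloaders, the model, optimizer and LR scheduler,
# handles (auto-)resume, then alternates between training epochs and the
# cls / seg / reid evaluations selected in cfg.evaluate.task.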
def train(cfg):
    if cfg.wandb and dist.get_rank() == 0:
        import wandb
        wandb.init(
            project='group_vit',
            name=osp.join(cfg.model_name, cfg.tag),
            dir=cfg.output,
            config=OmegaConf.to_container(cfg, resolve=True),
            resume=cfg.checkpoint.auto_resume)
    else:
        wandb = None
    # wait for wandb init
    dist.barrier()

    dataset_train, dataset_val, \
        data_loader_train, data_loader_val = build_loader(cfg.data)
    data_loader_seg = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))

    print("\n\n\n")
    print(cfg)
    print("\n\n\n")

    # build the image-text pair dataloaders
    # train_loader, val_img_loader, val_txt_loader, num_classes = build_dataloader(cfg)
    val_img_loader, val_txt_loader, num_classes = build_dataloader(cfg)

    logger = get_logger()

    logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
    model = build_model(cfg.model)
    # # load_checkpoint(cfg, model, None, None)
    # # Freeze all layers
    # for param in model.parameters():
    #     param.requires_grad = False
    # # To freeze only specific layers instead, proceed as follows,
    # # e.g. keep all the img_projector layers trainable:
    # for param in model.img_projector.parameters():
    #     param.requires_grad = True
    # # To freeze only specific layers instead, proceed as follows,
    # # e.g. keep all the text_projector layers trainable:
    # for param in model.text_projector.parameters():
    #     param.requires_grad = True
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(cfg.train, model)
    if cfg.train.amp_opt_level != 'O0':
        model, optimizer = amp.initialize(model, optimizer, opt_level=cfg.train.amp_opt_level)
    model = MMDistributedDataParallel(model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f'number of params: {n_parameters}')

    lr_scheduler = build_scheduler(cfg.train, optimizer, len(data_loader_train))

    if cfg.checkpoint.auto_resume:
        resume_file = auto_resume_helper(cfg.output)
        if resume_file:
            if cfg.checkpoint.resume:
                logger.warning(f'auto-resume changing resume file from {cfg.checkpoint.resume} to {resume_file}')
            with read_write(cfg):
                cfg.checkpoint.resume = resume_file
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {cfg.output}, ignoring auto resume')

    max_accuracy = max_miou = max_rank1 = 0.0
    max_metrics = {'max_accuracy': max_accuracy, 'max_miou': max_miou, 'max_rank1': max_rank1}

    if cfg.checkpoint.resume:
        max_metrics = load_checkpoint(cfg, model_without_ddp, optimizer, lr_scheduler)
        max_accuracy, max_miou = max_metrics['max_accuracy'], max_metrics['max_miou']
        if 'cls' in cfg.evaluate.task:
            acc1, acc5, loss = validate_cls(cfg, data_loader_val, model)
            logger.info(f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%')
        if 'seg' in cfg.evaluate.task:
            miou = validate_seg(cfg, data_loader_seg, model)
            logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
        if 'reid' in cfg.evaluate.task:
            # mrank1 = validate_reid(cfg, data_loader_reid, model)
            mrank1 = validate_reid(cfg, val_img_loader, val_txt_loader, model)
            # logger.info(f'Rank1 of the network on the {len(data_loader_reid)} test images: {mrank1:.2f}%')
            logger.info(f'Rank1 of the network on the {len(val_img_loader)} test images: {mrank1:.2f}%')
        if cfg.evaluate.eval_only:
            return

    logger.info('Start training')
    start_time = time.time()
    for epoch in range(cfg.train.start_epoch, cfg.train.epochs):
        loss_train_dict = train_one_epoch(cfg, model, data_loader_train, optimizer, epoch, lr_scheduler)
        if dist.get_rank() == 0 and (epoch % cfg.checkpoint.save_freq == 0 or epoch == (cfg.train.epochs - 1)):
            save_checkpoint(cfg, epoch, model_without_ddp, {
                'max_accuracy': max_accuracy,
                'max_miou': max_miou,
                'max_rank1': max_rank1
            }, optimizer, lr_scheduler)
        dist.barrier()
        loss_train = loss_train_dict['total_loss']
        logger.info(f'Avg loss of the network on the {len(dataset_train)} train images: {loss_train:.2f}')

        # evaluate
        if (epoch % cfg.evaluate.eval_freq == 0 or epoch == (cfg.train.epochs - 1)):
            if 'cls' in cfg.evaluate.task:
                acc1, acc5, loss = validate_cls(cfg, data_loader_val, model)
                logger.info(f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%')
                max_metrics['max_accuracy'] = max(max_metrics['max_accuracy'], acc1)
                if cfg.evaluate.cls.save_best and dist.get_rank() == 0 and acc1 > max_accuracy:
                    save_checkpoint(
                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_acc1')
                dist.barrier()
                max_accuracy = max_metrics['max_accuracy']
                logger.info(f'Max accuracy: {max_accuracy:.2f}%')
            if 'seg' in cfg.evaluate.task:
                miou = validate_seg(cfg, data_loader_seg, model)
                logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
                max_metrics['max_miou'] = max(max_metrics['max_miou'], miou)
                if cfg.evaluate.seg.save_best and dist.get_rank() == 0 and miou > max_miou:
                    save_checkpoint(
                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_miou')
                dist.barrier()
                max_miou = max_metrics['max_miou']
                logger.info(f'Max mIoU: {max_miou:.2f}%')
            if 'reid' in cfg.evaluate.task:
                mrank1 = validate_reid(cfg, val_img_loader, val_txt_loader, model)
                logger.info(f'mRank1 of the network on the {len(val_img_loader)} test images: {mrank1:.2f}%')
                max_metrics['max_rank1'] = max(max_metrics['max_rank1'], mrank1)
                if cfg.evaluate.reid.save_best and dist.get_rank() == 0 and mrank1 > max_rank1:
                    save_checkpoint(
                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_rank1')
                dist.barrier()
                max_rank1 = max_metrics['max_rank1']
                logger.info(f'Max mRank1: {max_rank1:.2f}%')

        if wandb is not None:
            log_stat = {f'epoch/train_{k}': v for k, v in loss_train_dict.items()}
            log_stat.update({
                'epoch/val_acc1': acc1,
                'epoch/val_acc5': acc5,
                'epoch/val_loss': loss,
                'epoch/val_miou': miou,
                'epoch/val_rank1': mrank1,
                'epoch/epoch': epoch,
                'epoch/n_parameters': n_parameters
            })
            wandb.log(log_stat)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    dist.barrier()

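# One training epoch. With accumulation_steps > 1 the loss is scaled down and
# optimizer.step() / lr_scheduler.step_update() fire only every accumulation_steps
# iterations; otherwise every batch does a full backward + (optional) grad clip + step.
# apex amp is used whenever amp_opt_level != 'O0'.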
def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler):
    logger = get_logger()
    dist.barrier()
    model.train()
    optimizer.zero_grad()
    if config.wandb and dist.get_rank() == 0:
        import wandb
    else:
        wandb = None

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    log_vars_meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    for idx, samples in enumerate(data_loader):
        # print('\n\n1\n\n')
        batch_size = config.data.batch_size
        losses = model(**samples)

        loss, log_vars = parse_losses(losses)

        if config.train.accumulation_steps > 1:
            loss = loss / config.train.accumulation_steps
            if config.train.amp_opt_level != 'O0':
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.train.accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            optimizer.zero_grad()
            if config.train.amp_opt_level != 'O0':
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()
        # print('\n\n2\n\n')

        loss_meter.update(loss.item(), batch_size)
        for loss_name in log_vars:
            log_vars_meters[loss_name].update(log_vars[loss_name], batch_size)
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()
        # print('\n\n3\n\n')

        if idx % config.print_freq == 0:
            # print('\n\n4\n\n')
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            log_vars_str = '\t'.join(f'{n} {m.val:.4f} ({m.avg:.4f})' for n, m in log_vars_meters.items())
            logger.info(f'Train: [{epoch}/{config.train.epochs}][{idx}/{num_steps}]\t'
                        f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                        f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                        f'total_loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'{log_vars_str}\t'
                        f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                        f'mem {memory_used:.0f}MB')
            # print('\n\n5\n\n')
            if wandb is not None:
                log_stat = {f'iter/train_{n}': m.avg for n, m in log_vars_meters.items()}
                log_stat['iter/train_total_loss'] = loss_meter.avg
                log_stat['iter/learning_rate'] = lr
                wandb.log(log_stat)
        # print('\n\n6\n\n')

    epoch_time = time.time() - start
    logger.info(f'EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}')
    result_dict = dict(total_loss=loss_meter.avg)
    for n, m in log_vars_meters.items():
        result_dict[n] = m.avg
    dist.barrier()
    return result_dict

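# Zero-shot ImageNet classification: class names are turned into text embeddings
# once via build_text_embedding, then each image batch is scored against them and
# top-1/top-5 accuracy and loss are averaged across ranks.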
@torch.no_grad()
def validate_cls(config, data_loader, model):
    logger = get_logger()
    dist.barrier()
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)

    end = time.time()
    logger.info('Building zero shot classifier')
    text_embedding = data2cuda(
        model.module.build_text_embedding(
            build_dataset_class_tokens(text_transform, config.evaluate.cls.template, imagenet_classes)))
    logger.info('Zero shot classifier built')
    for idx, samples in enumerate(data_loader):
        target = samples.pop('target').data[0].cuda()
        target = data2cuda(target)

        # compute output
        output = model(**samples, text=text_embedding)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.print_freq == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                        f'Mem {memory_used:.0f}MB')
    logger.info('Clearing zero shot classifier')
    torch.cuda.empty_cache()
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    dist.barrier()
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg

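# Zero-shot segmentation evaluation: wraps the model with build_seg_inference,
# runs mmseg's multi_gpu_test over the segmentation dataloader and reports mIoU (in %).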
@torch.no_grad()
def validate_seg(config, data_loader, model):
    logger = get_logger()
    dist.barrier()
    model.eval()

    if hasattr(model, 'module'):
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
    seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)

    mmddp_model = MMDistributedDataParallel(
        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    mmddp_model.eval()
    results = multi_gpu_test(
        model=mmddp_model,
        data_loader=data_loader,
        tmpdir=None,
        gpu_collect=True,
        efficient_test=False,
        pre_eval=True,
        format_only=False)

    if dist.get_rank() == 0:
        metric = [data_loader.dataset.evaluate(results, metric='mIoU')]
    else:
        metric = [None]
    dist.broadcast_object_list(metric)
    miou_result = metric[0]['mIoU'] * 100

    torch.cuda.empty_cache()
    logger.info(f'Eval Seg mIoU {miou_result:.2f}')
    dist.barrier()
    return miou_result

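# Text-to-image re-identification evaluation: delegates to utils.metrics.Evaluator,
# whose eval() is expected to return Rank-1 on the validation image/text loaders.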
@torch.no_grad()
def validate_reid(cfg, img_loader, txt_loader, model):
    logger = get_logger()
    dist.barrier()
    # model.eval()
    evaluator = Evaluator(img_loader, txt_loader)

    if hasattr(model, 'module'):
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    # reid_model = build_reid_inference(model_without_ddp, img_loader, txt_loader, cfg.evaluate.reid)
    # mmddp_model = MMDistributedDataParallel(
    #     reid_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    rank1 = evaluator.eval(model_without_ddp.eval())
    # results = multi_gpu_test(
    #     model=mmddp_model,
    #     data_loader=img_loader,
    #     tmpdir=None,
    #     gpu_collect=True,
    #     efficient_test=False,
    #     pre_eval=True,
    #     format_only=False)
    # if dist.get_rank() == 0:
    #     metric = [img_loader.dataset.evaluate(results, metric='Rank-1')]
    # else:
    #     metric = [None]
    # dist.broadcast_object_list(metric)
    # rank1_result = metric[0]['Rank-1'] * 100

    torch.cuda.empty_cache()
    dist.barrier()
    return rank1

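# Entry point: parses the config, initializes distributed training, linearly scales
# the learning rates with the total batch size (and accumulation steps), dumps the
# resolved config and environment info, then calls train(cfg).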
def main():
    args = parse_args()
    cfg = get_config(args)

    if cfg.train.amp_opt_level != 'O0':
        assert amp is not None, 'amp not installed!'

    # start faster ref: https://github.com/open-mmlab/mmdetection/pull/7036
    mp.set_start_method('fork', force=True)
    init_dist('pytorch')
    rank, world_size = get_dist_info()
    print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')
    dist.barrier()

    set_random_seed(cfg.seed, use_rank_shift=True)
    cudnn.benchmark = True

    os.makedirs(cfg.output, exist_ok=True)
    logger = get_logger(cfg)

    # linearly scale the learning rate with the total batch size; may not be optimal
    linear_scaled_lr = cfg.train.base_lr * cfg.data.batch_size * world_size / 4096.0
    linear_scaled_warmup_lr = cfg.train.warmup_lr * cfg.data.batch_size * world_size / 4096.0
    linear_scaled_min_lr = cfg.train.min_lr * cfg.data.batch_size * world_size / 4096.0

    # gradient accumulation also needs to scale the learning rate
    if cfg.train.accumulation_steps > 1:
        linear_scaled_lr = linear_scaled_lr * cfg.train.accumulation_steps
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * cfg.train.accumulation_steps
        linear_scaled_min_lr = linear_scaled_min_lr * cfg.train.accumulation_steps

    with read_write(cfg):
        logger.info(f'Scale base_lr from {cfg.train.base_lr} to {linear_scaled_lr}')
        logger.info(f'Scale warmup_lr from {cfg.train.warmup_lr} to {linear_scaled_warmup_lr}')
        logger.info(f'Scale min_lr from {cfg.train.min_lr} to {linear_scaled_min_lr}')
        cfg.train.base_lr = linear_scaled_lr
        cfg.train.warmup_lr = linear_scaled_warmup_lr
        cfg.train.min_lr = linear_scaled_min_lr

    if dist.get_rank() == 0:
        path = os.path.join(cfg.output, 'config.json')
        OmegaConf.save(cfg, path)
        logger.info(f'Full config saved to {path}')

    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
    logger.info(f'Git hash: {get_git_hash(digits=7)}')

    # print config
    logger.info(OmegaConf.to_yaml(cfg))

    train(cfg)
    dist.barrier()


if __name__ == '__main__':
    main()