main_pretrain.py
# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# Modified by Jilan Xu
# -------------------------------------------------------------------------
import argparse
import datetime
import os
import os.path as osp
import subprocess
import sys
import time
from collections import defaultdict

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
from datasets import build_loader, build_text_transform, imagenet_classes
from einops import rearrange
from ipdb import set_trace
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import collect_env, get_git_hash
from mmseg.apis import multi_gpu_test
from models import build_model
from omegaconf import OmegaConf, read_write
from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
from timm.utils import AverageMeter, accuracy
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, RobertaTokenizer
from utils import (auto_resume_helper, build_dataset_class_tokens, build_optimizer, build_scheduler, data2cuda,
                   get_config, get_grad_norm, get_logger, load_checkpoint, parse_losses, reduce_tensor,
                   save_checkpoint, momentum_update, load_checkpoint_stage1, build_dataset_class_lists, cdist_)

tokenizer_dict = {
    'Bert': AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False),
    'TextTransformer': None,
}

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None
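
# NVIDIA apex is optional: if it is not installed, `amp` stays None and training
# falls back to full precision (the default --amp-opt-level 'O0' disables AMP).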

def parse_args():
    parser = argparse.ArgumentParser('GroupViT training and evaluation script')
    parser.add_argument('--cfg', type=str, required=True, help='path to config file')
    parser.add_argument('--opts', help="Modify config options by adding 'KEY=VALUE' list. ", default=None, nargs='+')

    # easy config modification
    parser.add_argument('--batch-size', type=int, help='batch size for single GPU')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument(
        '--amp-opt-level',
        type=str,
        # default='O1',
        default='O0',
        choices=['O0', 'O1', 'O2'],
        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument(
        '--output', type=str, help='root of output folder, '
        'the full path is <output>/<model_name>/<tag>')
    parser.add_argument('--tag', type=str, help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--wandb', action='store_true', help='Use W&B to log experiments')
    parser.add_argument('--keep', type=int, help='Maximum checkpoints to keep')

    # distributed training
    parser.add_argument('--local_rank', type=int, required=False, default=0, help='local rank for DistributedDataParallel')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    args = parser.parse_args()
    return args
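

# train() is the top-level driver: it builds the train/val/seg data loaders, the model,
# optimizer and LR scheduler, optionally restores stage-1 weights or an auto-resumed
# checkpoint, then alternates between train_one_epoch() and the 'cls'/'seg' evaluations
# below, saving checkpoints and logging to TensorBoard (and optionally W&B) along the way.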
def train(cfg):
    if cfg.wandb and dist.get_rank() == 0:
        import wandb
        wandb.init(
            project='group_vit',
            name=osp.join(cfg.model_name, cfg.tag),
            dir=cfg.output,
            config=OmegaConf.to_container(cfg, resolve=True),
            resume=cfg.checkpoint.auto_resume)
    else:
        wandb = None
    # wait for wandb init
    dist.barrier()

    dataset_train, dataset_val, \
        data_loader_train, data_loader_val = build_loader(cfg.data)
    print('Done train/val loader')
    data_loader_seg = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))
    print('Done seg loader')

    logger = get_logger()
    if dist.get_rank() == 0:
        writer = SummaryWriter(cfg.output)
    else:
        writer = None

    logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
    model = build_model(cfg.model)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(cfg.train, model)
    if cfg.train.amp_opt_level != 'O0':
        model, optimizer = amp.initialize(model, optimizer, opt_level=cfg.train.amp_opt_level)
    model = MMDistributedDataParallel(model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=True)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f'number of params: {n_parameters}')

    lr_scheduler = build_scheduler(cfg.train, optimizer, len(data_loader_train))

    ##### load init params from stage 1 here, before auto resuming ######
    if cfg.checkpoint.stage1_checkpoint:
        load_checkpoint_stage1(cfg, model_without_ddp)

    if cfg.checkpoint.auto_resume:
        resume_file = auto_resume_helper(cfg.output)
        if resume_file:
            if cfg.checkpoint.resume:
                logger.warning(f'auto-resume changing resume file from {cfg.checkpoint.resume} to {resume_file}')
            with read_write(cfg):
                cfg.checkpoint.resume = resume_file
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {cfg.output}, ignoring auto resume')

    max_accuracy = max_miou = 0.0
    max_metrics = {'max_accuracy': max_accuracy, 'max_miou': max_miou}

    if cfg.checkpoint.resume:
        max_metrics = load_checkpoint(cfg, model_without_ddp, optimizer, lr_scheduler)
        max_accuracy, max_miou = max_metrics['max_accuracy'], max_metrics['max_miou']

    ############# set tokenizer ##############
    global tokenizer
    tokenizer = tokenizer_dict[cfg.model.text_encoder.type]
    tensorbd_logdir = cfg.output + "/logs"

    logger.info('Start training')
    start_time = time.time()
    for epoch in range(cfg.train.start_epoch, cfg.train.epochs):
        ### train model ###
        loss_train_dict = train_one_epoch(cfg, model, data_loader_train, optimizer, epoch, lr_scheduler, writer)
        if dist.get_rank() == 0 and (epoch % cfg.checkpoint.save_freq == 0 or epoch == (cfg.train.epochs - 1)):
            save_checkpoint(cfg, epoch, model_without_ddp, {
                'max_accuracy': max_accuracy,
                'max_miou': max_miou
            }, optimizer, lr_scheduler)
        dist.barrier()
        loss_train = loss_train_dict['total_loss']
        logger.info(f'Avg loss of the network on the {len(dataset_train)} train images: {loss_train:.2f}')

        # evaluate
        if (epoch % cfg.evaluate.eval_freq == 0 or epoch == (cfg.train.epochs - 1)):
            if 'cls' in cfg.evaluate.task:
                acc1, acc5, loss = validate_cls(cfg, data_loader_val, model)
                logger.info(f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%')
                max_metrics['max_accuracy'] = max(max_metrics['max_accuracy'], acc1)
                # if cfg.evaluate.cls.save_best and dist.get_rank() == 0 and acc1 > max_accuracy:
                #     save_checkpoint(
                #         cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_acc1')
                dist.barrier()
                max_accuracy = max_metrics['max_accuracy']
                logger.info(f'Max accuracy: {max_accuracy:.2f}%')
            if 'seg' in cfg.evaluate.task:
                miou = validate_seg(cfg, data_loader_seg, model, epoch, writer, tokenizer=tokenizer)
                logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
                max_metrics['max_miou'] = max(max_metrics['max_miou'], miou)
                if cfg.evaluate.seg.save_best and dist.get_rank() == 0 and miou > max_miou:
                    print('saving the best mIoU model')
                    save_checkpoint(
                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_miou')
                dist.barrier()
                max_miou = max_metrics['max_miou']
                logger.info(f'Max mIoU: {max_miou:.2f}%')

        if wandb is not None:
            log_stat = {f'epoch/train_{k}': v for k, v in loss_train_dict.items()}
            log_stat.update({
                'epoch/val_acc1': acc1,
                'epoch/val_acc5': acc5,
                'epoch/val_loss': loss,
                'epoch/val_miou': miou,
                'epoch/epoch': epoch,
                'epoch/n_parameters': n_parameters
            })
            wandb.log(log_stat)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    dist.barrier()
    # writer.flush()


def process_text(text_data):
    ### we run all the exps with padding=True, meaning padding to the longest caption ###
    # text_data = tokenizer(text_data, return_tensors='pt', padding=True,
    #                       truncation=True, max_length=77)
    ### this is more memory friendly/load balanced if we chunk the padding size to max_length ###
    text_data = tokenizer(text_data, return_tensors='pt', padding='max_length',
                          truncation=True, max_length=77)
    text_data = {key: val.cuda() for key, val in text_data.items()}
    return text_data
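

# The entity mask flags which token positions hold the [MASK] token (id 103 in the
# uncased BERT/DistilBERT vocabulary used above); all other positions are zeroed out.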
def generate_entity_masks(text_data):
    text = text_data['input_ids']
    # [b, L]
    entity_masks = text.clone()
    entity_masks[entity_masks != 103] = 0
    entity_masks[entity_masks == 103] = 1
    entity_masks = entity_masks.to(text.device)
    return entity_masks
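

# One pass over the training loader. For BERT-style text encoders the raw captions
# (and, with use_entity, the question/answer prompts) are tokenized on the fly via
# process_text(); the losses returned by the model are reduced with parse_losses(),
# and the backward/step logic supports optional apex AMP, gradient clipping, gradient
# accumulation, and a momentum (EMA) update of the image encoder when the mask
# consistency loss is enabled.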
def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, writer):
    logger = get_logger()
    dist.barrier()
    model.train()
    optimizer.zero_grad()
    if config.wandb and dist.get_rank() == 0:
        import wandb
    else:
        wandb = None

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    log_vars_meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)

    for idx, samples in enumerate(data_loader):
        batch_size = config.data.train.batch_size
        all_images = samples['image'].cuda()
        all_questions = None
        entity_labels = entity_masks = None
        all_answers = None
        if config.model.text_encoder['type'] in ['DistilBert', 'Bert', 'BertMedium', 'Roberta']:
            all_texts = process_text(samples['raw_caption'])
            if config.data.train.use_entity is True:
                all_questions = process_text(samples['raw_question'])
                all_answers = process_text(samples['raw_answer'])
                entity_masks = generate_entity_masks(all_questions)
        elif config.model.text_encoder['type'] not in ['TextTransformer'] and config.data.train.use_entity is True:
            all_texts = samples['caption'].cuda()
            all_questions = samples['question'].cuda()
            all_answers = samples['answer'].cuda()
        else:
            all_texts = samples['caption'].cuda()

        ### for cross-image mask consistency loss ###
        all_crossimage = samples['cross_image'].cuda() if 'cross_image' in samples and samples['cross_image'] is not None else None
        question_masks = samples['question_mask'].cuda() if 'question_mask' in samples else None
        cross_entity = process_text(samples['cross_entity']) if 'cross_entity' in samples and samples['cross_entity'] is not None else None

        ### forward and compute loss ###
        losses = model(image=all_images, text=all_texts, cross_image=all_crossimage, cross_entity=cross_entity,
                       question=all_questions, answer=all_answers, entity_masks=entity_masks, question_masks=question_masks)
        loss, log_vars = parse_losses(losses)

        if dist.get_rank() == 0:
            writer.add_scalar("Total loss", loss, len(data_loader) * epoch + idx)
            writer.add_scalar("contrastive loss", losses['loss'], len(data_loader) * epoch + idx)
            if 'entity' in losses:
                writer.add_scalar("entity loss", losses['entity'], len(data_loader) * epoch + idx)
            if 'mask' in losses:
                writer.add_scalar("Mask loss", losses['mask'], len(data_loader) * epoch + idx)
            writer.add_scalar("lr", optimizer.param_groups[0]['lr'], len(data_loader) * epoch + idx)

        if config.train.accumulation_steps > 1:
            loss = loss / config.train.accumulation_steps
            if config.train.amp_opt_level != 'O0':
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.train.accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
                if config.model.use_maskloss:
                    maskloss_coeff = 0.99
                    momentum_update(model.module.img_encoder, model.module.img_encoder_momentum, maskloss_coeff)
        else:
            optimizer.zero_grad()
            if config.train.amp_opt_level != 'O0':
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)
            if config.model.use_maskloss:
                maskloss_coeff = 0.99
                momentum_update(model.module.img_encoder, model.module.img_encoder_momentum, maskloss_coeff)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), batch_size)
        for loss_name in log_vars:
            log_vars_meters[loss_name].update(log_vars[loss_name], batch_size)
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.print_freq == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            log_vars_str = '\t'.join(f'{n} {m.val:.4f} ({m.avg:.4f})' for n, m in log_vars_meters.items())
            logger.info(f'Train: [{epoch}/{config.train.epochs}][{idx}/{num_steps}]\t'
                        f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                        f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                        f'total_loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'{log_vars_str}\t'
                        f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                        f'mem {memory_used:.0f}MB')
            if wandb is not None:
                log_stat = {f'iter/train_{n}': m.avg for n, m in log_vars_meters.items()}
                log_stat['iter/train_total_loss'] = loss_meter.avg
                log_stat['iter/learning_rate'] = lr
                wandb.log(log_stat)

    epoch_time = time.time() - start
    logger.info(f'EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}')

    result_dict = dict(total_loss=loss_meter.avg)
    for n, m in log_vars_meters.items():
        result_dict[n] = m.avg
    dist.barrier()
    return result_dict
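

# Zero-shot ImageNet classification: the prompt templates from config.evaluate.cls are
# filled with the ImageNet class names to build one text embedding per class, and each
# image is scored against those embeddings to compute top-1 / top-5 accuracy.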
@torch.no_grad()
def validate_cls(config, data_loader, model):
    logger = get_logger()
    dist.barrier()
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
    end = time.time()
    logger.info('Building zero shot classifier')
    if config.model.text_encoder['type'] in ['DistilBert', 'Bert', 'BertMedium', 'Roberta']:
        text_embedding = model.module.build_text_embedding(
            build_dataset_class_lists(config.evaluate.cls.template, imagenet_classes), tokenizer, len(imagenet_classes))
    else:
        text_embedding = data2cuda(
            model.module.build_text_embedding(
                build_dataset_class_tokens(text_transform, config.evaluate.cls.template, imagenet_classes)))
    logger.info('Zero shot classifier built')

    for idx, samples in enumerate(data_loader):
        all_images = samples['image'].cuda()
        target = samples['target'].cuda()
        output = model(image=all_images, text=text_embedding)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.print_freq == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                        f'Mem {memory_used:.0f}MB')

    logger.info('Clearing zero shot classifier')
    torch.cuda.empty_cache()
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    dist.barrier()
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
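

# Zero-shot segmentation evaluation: the model is wrapped by build_seg_inference()
# into an mmseg-compatible segmentor, run with multi_gpu_test across ranks, and the
# resulting mIoU is computed on rank 0 and broadcast to every process.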
@torch.no_grad()
def validate_seg(config, data_loader, model, epoch=0, writer=None, tokenizer=None):
    logger = get_logger()
    dist.barrier()
    model.eval()

    if hasattr(model, 'module'):
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
    if config.model.text_encoder['type'] in ['DistilBert', 'Bert', 'BertMedium', 'Roberta']:
        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg, tokenizer)
    else:
        seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)

    mmddp_model = MMDistributedDataParallel(
        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    mmddp_model.eval()

    results = multi_gpu_test(
        model=mmddp_model,
        data_loader=data_loader,
        tmpdir=None,
        gpu_collect=True,
        efficient_test=False,
        pre_eval=True,
        format_only=False)

    if dist.get_rank() == 0:
        metric = [data_loader.dataset.evaluate(results, metric='mIoU')]
    else:
        metric = [None]
    dist.broadcast_object_list(metric)
    miou_result = metric[0]['mIoU'] * 100

    torch.cuda.empty_cache()
    logger.info(f'Eval Seg mIoU {miou_result:.2f}')
    if writer is not None and dist.get_rank() == 0:
        writer.add_scalar("mIoU", miou_result, epoch)
    dist.barrier()
    return miou_result


def setup_for_distributed(is_master):
    """
    Disable printing when not in the master process, unless print() is
    called with force=True.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
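

# Distributed setup covers three launch modes: torch.distributed.launch-style
# environment variables (RANK/WORLD_SIZE/LOCAL_RANK), SLURM (rank and master address
# derived from SLURM_* variables via scontrol), or a single-GPU fallback.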
def init_distributed_mode(args):
    # launched with torch.distributed.launch
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    # launched with submitit on a slurm cluster
    elif 'SLURM_PROCID' in os.environ:
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = os.environ['SLURM_NTASKS']
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        master_port = os.environ.get('MASTER_PORT', '29488')
        os.environ['MASTER_PORT'] = master_port
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = int(ntasks)
        args.rank = int(proc_id)
        args.gpu = int(proc_id % num_gpus)
        print(f'SLURM MODE: proc_id: {proc_id}, ntasks: {ntasks}, node_list: {node_list}, '
              f'num_gpus: {num_gpus}, addr: {addr}, master_port: {master_port}')
    # launched naively with `python main_pretrain.py`
    # we manually add MASTER_ADDR and MASTER_PORT to the environment variables
    elif torch.cuda.is_available():
        print('Will run the code on one GPU.')
        args.rank, args.gpu, args.world_size = 0, 0, 1
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
    else:
        print('Training without a GPU is not supported.')
        sys.exit(1)

    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.cuda.set_device(args.gpu)
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    dist.barrier()
    setup_for_distributed(args.rank == 0)


def main():
    args = parse_args()
    cfg = get_config(args)

    if cfg.train.amp_opt_level != 'O0':
        assert amp is not None, 'amp not installed!'

    '''
    # start faster ref: https://github.com/open-mmlab/mmdetection/pull/7036
    mp.set_start_method('fork', force=True)
    init_dist('pytorch')
    rank, world_size = get_dist_info()
    print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')
    dist.barrier()
    '''
    init_distributed_mode(args)
    rank, world_size = args.rank, args.world_size

    set_random_seed(cfg.seed, use_rank_shift=True)
    cudnn.benchmark = True

    os.makedirs(cfg.output, exist_ok=True)
    logger = get_logger(cfg)

    # linearly scale the learning rate with the total batch size; may not be optimal
    linear_scaled_lr = cfg.train.base_lr * cfg.data.train.batch_size * world_size / 4096.0
    linear_scaled_warmup_lr = cfg.train.warmup_lr * cfg.data.train.batch_size * world_size / 4096.0
    linear_scaled_min_lr = cfg.train.min_lr * cfg.data.train.batch_size * world_size / 4096.0

    # gradient accumulation also needs to scale the learning rate
    if cfg.train.accumulation_steps > 1:
        linear_scaled_lr = linear_scaled_lr * cfg.train.accumulation_steps
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * cfg.train.accumulation_steps
        linear_scaled_min_lr = linear_scaled_min_lr * cfg.train.accumulation_steps
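
    # Example (hypothetical numbers): with a per-GPU batch size of 256, 8 GPUs and
    # accumulation_steps 1, the scaling factor is 256 * 8 / 4096 = 0.5, so a
    # base_lr of 1.6e-3 becomes 8e-4.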

    with read_write(cfg):
        logger.info(f'Scale base_lr from {cfg.train.base_lr} to {linear_scaled_lr}')
        logger.info(f'Scale warmup_lr from {cfg.train.warmup_lr} to {linear_scaled_warmup_lr}')
        logger.info(f'Scale min_lr from {cfg.train.min_lr} to {linear_scaled_min_lr}')
        cfg.train.base_lr = linear_scaled_lr
        cfg.train.warmup_lr = linear_scaled_warmup_lr
        cfg.train.min_lr = linear_scaled_min_lr

    if dist.get_rank() == 0:
        path = os.path.join(cfg.output, 'config.json')
        OmegaConf.save(cfg, path)
        logger.info(f'Full config saved to {path}')

    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
    logger.info(f'Git hash: {get_git_hash(digits=7)}')

    # print config
    logger.info(OmegaConf.to_yaml(cfg))

    train(cfg)
    dist.barrier()


if __name__ == '__main__':
    main()