main_group_vit.py

# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
import argparse
import datetime
import os
import os.path as osp
import time
from collections import defaultdict

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
from datasets import build_loader, build_text_transform, imagenet_classes
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import collect_env, get_git_hash
from mmseg.apis import multi_gpu_test
from models import build_model
from omegaconf import OmegaConf, read_write
from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
from timm.utils import AverageMeter, accuracy
from utils import (auto_resume_helper, build_dataset_class_tokens, build_optimizer, build_scheduler, data2cuda,
                   get_config, get_grad_norm, get_logger, load_checkpoint, parse_losses, reduce_tensor,
                   save_checkpoint)
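
# NVIDIA apex is optional: when it is not installed, `amp` stays None and mixed
# precision (any amp_opt_level other than 'O0') cannot be used (checked in main()).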
try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None


def parse_args():
    parser = argparse.ArgumentParser('GroupViT training and evaluation script')
    parser.add_argument('--cfg', type=str, required=True, help='path to config file')
    parser.add_argument('--opts', help="Modify config options by adding 'KEY=VALUE' list. ", default=None, nargs='+')

    # easy config modification
    parser.add_argument('--batch-size', type=int, help='batch size for single GPU')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument(
        '--amp-opt-level',
        type=str,
        default='O1',
        choices=['O0', 'O1', 'O2'],
        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument(
        '--output', type=str, help='root of output folder, '
        'the full path is <output>/<model_name>/<tag>')
    parser.add_argument('--tag', type=str, help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--wandb', action='store_true', help='Use W&B to log experiments')
    parser.add_argument('--keep', type=int, help='Maximum checkpoint to keep')

    # distributed training
    parser.add_argument('--local_rank', type=int, required=True, help='local rank for DistributedDataParallel')

    args = parser.parse_args()

    return args


def train(cfg):
    if cfg.wandb and dist.get_rank() == 0:
        import wandb
        wandb.init(
            project='group_vit',
            name=osp.join(cfg.model_name, cfg.tag),
            dir=cfg.output,
            config=OmegaConf.to_container(cfg, resolve=True),
            resume=cfg.checkpoint.auto_resume)
    else:
        wandb = None
    # waiting wandb init
    dist.barrier()

    dataset_train, dataset_val, \
        data_loader_train, data_loader_val = build_loader(cfg.data)
    data_loader_seg = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))

    logger = get_logger()

    logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
    model = build_model(cfg.model)
    # load_checkpoint(cfg, model, None, None)
    # Freeze all parameters first.
    for param in model.parameters():
        param.requires_grad = False
    # To leave only specific sub-modules trainable, re-enable them individually.
    # Keep the image projector trainable:
    for param in model.img_projector.parameters():
        param.requires_grad = True
    # Keep the text projector trainable:
    for param in model.text_projector.parameters():
        param.requires_grad = True
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(cfg.train, model)
    if cfg.train.amp_opt_level != 'O0':
        model, optimizer = amp.initialize(model, optimizer, opt_level=cfg.train.amp_opt_level)
    model = MMDistributedDataParallel(model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    model_without_ddp = model.module
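    # Only parameters with requires_grad=True are counted, i.e. the projector heads left trainable above.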
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f'number of params: {n_parameters}')

    lr_scheduler = build_scheduler(cfg.train, optimizer, len(data_loader_train))

    if cfg.checkpoint.auto_resume:
        resume_file = auto_resume_helper(cfg.output)
        if resume_file:
            if cfg.checkpoint.resume:
                logger.warning(f'auto-resume changing resume file from {cfg.checkpoint.resume} to {resume_file}')
            with read_write(cfg):
                cfg.checkpoint.resume = resume_file
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {cfg.output}, ignoring auto resume')

    max_accuracy = max_miou = 0.0
    max_metrics = {'max_accuracy': max_accuracy, 'max_miou': max_miou}

    if cfg.checkpoint.resume:
        max_metrics = load_checkpoint(cfg, model_without_ddp, optimizer, lr_scheduler)
        max_accuracy, max_miou = max_metrics['max_accuracy'], max_metrics['max_miou']
        if 'cls' in cfg.evaluate.task:
            acc1, acc5, loss = validate_cls(cfg, data_loader_val, model)
            logger.info(f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%')
        if 'seg' in cfg.evaluate.task:
            miou = validate_seg(cfg, data_loader_seg, model)
            logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
        if cfg.evaluate.eval_only:
            return

    logger.info('Start training')
    start_time = time.time()
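    # Each epoch: train, save a checkpoint every save_freq epochs (rank 0 only),
    # then run zero-shot classification / segmentation evaluation every eval_freq epochs.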
    for epoch in range(cfg.train.start_epoch, cfg.train.epochs):
        loss_train_dict = train_one_epoch(cfg, model, data_loader_train, optimizer, epoch, lr_scheduler)
        if dist.get_rank() == 0 and (epoch % cfg.checkpoint.save_freq == 0 or epoch == (cfg.train.epochs - 1)):
            save_checkpoint(cfg, epoch, model_without_ddp, {
                'max_accuracy': max_accuracy,
                'max_miou': max_miou
            }, optimizer, lr_scheduler)
        dist.barrier()
        loss_train = loss_train_dict['total_loss']
        logger.info(f'Avg loss of the network on the {len(dataset_train)} train images: {loss_train:.2f}')

        # evaluate
        if (epoch % cfg.evaluate.eval_freq == 0 or epoch == (cfg.train.epochs - 1)):
            if 'cls' in cfg.evaluate.task:
                acc1, acc5, loss = validate_cls(cfg, data_loader_val, model)
                logger.info(f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%')
                max_metrics['max_accuracy'] = max(max_metrics['max_accuracy'], acc1)
                if cfg.evaluate.cls.save_best and dist.get_rank() == 0 and acc1 > max_accuracy:
                    save_checkpoint(
                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_acc1')
                dist.barrier()
                max_accuracy = max_metrics['max_accuracy']
                logger.info(f'Max accuracy: {max_accuracy:.2f}%')
            if 'seg' in cfg.evaluate.task:
                miou = validate_seg(cfg, data_loader_seg, model)
                logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
                max_metrics['max_miou'] = max(max_metrics['max_miou'], miou)
                if cfg.evaluate.seg.save_best and dist.get_rank() == 0 and miou > max_miou:
                    save_checkpoint(
                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_miou')
                dist.barrier()
                max_miou = max_metrics['max_miou']
                logger.info(f'Max mIoU: {max_miou:.2f}%')

        if wandb is not None:
            log_stat = {f'epoch/train_{k}': v for k, v in loss_train_dict.items()}
            log_stat.update({
                'epoch/val_acc1': acc1,
                'epoch/val_acc5': acc5,
                'epoch/val_loss': loss,
                'epoch/val_miou': miou,
                'epoch/epoch': epoch,
                'epoch/n_parameters': n_parameters
            })
            wandb.log(log_stat)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    dist.barrier()


def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler):
    logger = get_logger()
    dist.barrier()
    model.train()
    optimizer.zero_grad()
    if config.wandb and dist.get_rank() == 0:
        import wandb
    else:
        wandb = None

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    log_vars_meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    for idx, samples in enumerate(data_loader):
        batch_size = config.data.batch_size

        losses = model(**samples)
        loss, log_vars = parse_losses(losses)

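        # With gradient accumulation enabled, scale the loss down and only step the
        # optimizer (and LR scheduler) every accumulation_steps iterations.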
        if config.train.accumulation_steps > 1:
            loss = loss / config.train.accumulation_steps
            if config.train.amp_opt_level != 'O0':
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.train.accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
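            # No gradient accumulation: backward, (optionally) clip, and step every iteration.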
            optimizer.zero_grad()
            if config.train.amp_opt_level != 'O0':
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), batch_size)
        for loss_name in log_vars:
            log_vars_meters[loss_name].update(log_vars[loss_name], batch_size)
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.print_freq == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            log_vars_str = '\t'.join(f'{n} {m.val:.4f} ({m.avg:.4f})' for n, m in log_vars_meters.items())
            logger.info(f'Train: [{epoch}/{config.train.epochs}][{idx}/{num_steps}]\t'
                        f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                        f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                        f'total_loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'{log_vars_str}\t'
                        f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                        f'mem {memory_used:.0f}MB')
            if wandb is not None:
                log_stat = {f'iter/train_{n}': m.avg for n, m in log_vars_meters.items()}
                log_stat['iter/train_total_loss'] = loss_meter.avg
                log_stat['iter/learning_rate'] = lr
                wandb.log(log_stat)

    epoch_time = time.time() - start
    logger.info(f'EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}')
    result_dict = dict(total_loss=loss_meter.avg)
    for n, m in log_vars_meters.items():
        result_dict[n] = m.avg
    dist.barrier()
    return result_dict


@torch.no_grad()
def validate_cls(config, data_loader, model):
    logger = get_logger()
    dist.barrier()
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)

    end = time.time()
    logger.info('Building zero shot classifier')
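    # The zero-shot classifier weights are the text embeddings of every ImageNet class name,
    # rendered through the prompt templates in config.evaluate.cls.template.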
    text_embedding = data2cuda(
        model.module.build_text_embedding(
            build_dataset_class_tokens(text_transform, config.evaluate.cls.template, imagenet_classes)))
    logger.info('Zero shot classifier built')
    for idx, samples in enumerate(data_loader):
        target = samples.pop('target').data[0].cuda()
        target = data2cuda(target)

        # compute output
        output = model(**samples, text=text_embedding)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.print_freq == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                        f'Mem {memory_used:.0f}MB')
    logger.info('Clearing zero shot classifier')
    torch.cuda.empty_cache()
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    dist.barrier()
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg


@torch.no_grad()
def validate_seg(config, data_loader, model):
    logger = get_logger()
    dist.barrier()
    model.eval()

    if hasattr(model, 'module'):
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
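    # build_seg_inference wraps the model for mmseg-style inference on this evaluation dataset,
    # so it can be run through multi_gpu_test below.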
    seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)

    mmddp_model = MMDistributedDataParallel(
        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    mmddp_model.eval()
    results = multi_gpu_test(
        model=mmddp_model,
        data_loader=data_loader,
        tmpdir=None,
        gpu_collect=True,
        efficient_test=False,
        pre_eval=True,
        format_only=False)

    if dist.get_rank() == 0:
        metric = [data_loader.dataset.evaluate(results, metric='mIoU')]
    else:
        metric = [None]
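    # Metrics are computed on rank 0 only; broadcast them so every rank returns the same mIoU.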
    dist.broadcast_object_list(metric)
    miou_result = metric[0]['mIoU'] * 100

    torch.cuda.empty_cache()
    logger.info(f'Eval Seg mIoU {miou_result:.2f}')
    dist.barrier()
    return miou_result


def main():
    args = parse_args()
    cfg = get_config(args)

    if cfg.train.amp_opt_level != 'O0':
        assert amp is not None, 'amp not installed!'

    # start faster ref: https://github.com/open-mmlab/mmdetection/pull/7036
    mp.set_start_method('fork', force=True)
    init_dist('pytorch')
    rank, world_size = get_dist_info()
    print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')

    dist.barrier()

    set_random_seed(cfg.seed, use_rank_shift=True)
    cudnn.benchmark = True

    os.makedirs(cfg.output, exist_ok=True)
    logger = get_logger(cfg)

    # linear scale the learning rate according to total batch size, may not be optimal
    linear_scaled_lr = cfg.train.base_lr * cfg.data.batch_size * world_size / 4096.0
    linear_scaled_warmup_lr = cfg.train.warmup_lr * cfg.data.batch_size * world_size / 4096.0
    linear_scaled_min_lr = cfg.train.min_lr * cfg.data.batch_size * world_size / 4096.0

    # gradient accumulation also needs to scale the learning rate
    if cfg.train.accumulation_steps > 1:
        linear_scaled_lr = linear_scaled_lr * cfg.train.accumulation_steps
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * cfg.train.accumulation_steps
        linear_scaled_min_lr = linear_scaled_min_lr * cfg.train.accumulation_steps

    with read_write(cfg):
        logger.info(f'Scale base_lr from {cfg.train.base_lr} to {linear_scaled_lr}')
        logger.info(f'Scale warmup_lr from {cfg.train.warmup_lr} to {linear_scaled_warmup_lr}')
        logger.info(f'Scale min_lr from {cfg.train.min_lr} to {linear_scaled_min_lr}')
        cfg.train.base_lr = linear_scaled_lr
        cfg.train.warmup_lr = linear_scaled_warmup_lr
        cfg.train.min_lr = linear_scaled_min_lr

    if dist.get_rank() == 0:
        path = os.path.join(cfg.output, 'config.json')
        OmegaConf.save(cfg, path)
        logger.info(f'Full config saved to {path}')

    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)

    logger.info(f'Git hash: {get_git_hash(digits=7)}')

    # print config
    logger.info(OmegaConf.to_yaml(cfg))

    train(cfg)
    dist.barrier()


if __name__ == '__main__':
    main()