# main_group_vit.py

# -------------------------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
#
# Written by Ze Liu, Zhenda Xie
# Modified by Jiarui Xu
# -------------------------------------------------------------------------
import argparse
import datetime
import os
import os.path as osp
import time
from collections import defaultdict

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
from datasets import build_loader, build_text_transform, imagenet_classes
from datasets.build import build_dataloader
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import collect_env, get_git_hash
from mmseg.apis import multi_gpu_test
from models import build_model
from omegaconf import OmegaConf, read_write
from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
from timm.utils import AverageMeter, accuracy
from utils import (auto_resume_helper, build_dataset_class_tokens, build_optimizer, build_scheduler, data2cuda,
                   get_config, get_grad_norm, get_logger, load_checkpoint, parse_losses, reduce_tensor,
                   save_checkpoint)

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None
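# `amp` is NVIDIA Apex mixed precision: the O0/O1/O2 modes selected with
# --amp-opt-level require apex to be installed. When the import fails, amp stays
# None and only opt level 'O0' (no mixed precision) can be used.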
def parse_args():
    parser = argparse.ArgumentParser('GroupViT training and evaluation script')
    parser.add_argument('--cfg', type=str, required=True, help='path to config file')
    parser.add_argument('--opts', help="Modify config options by adding 'KEY=VALUE' list. ", default=None, nargs='+')

    # easy config modification
    parser.add_argument('--batch-size', type=int, help='batch size for single GPU')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument(
        '--amp-opt-level',
        type=str,
        default='O1',
        choices=['O0', 'O1', 'O2'],
        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument(
        '--output', type=str, help='root of output folder, '
        'the full path is <output>/<model_name>/<tag>')
    parser.add_argument('--tag', type=str, help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--wandb', action='store_true', help='Use W&B to log experiments')
    parser.add_argument('--keep', type=int, help='maximum number of checkpoints to keep')

    # distributed training
    parser.add_argument('--local_rank', type=int, required=True, help='local rank for DistributedDataParallel')

    args = parser.parse_args()

    return args
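# Typical launch command (illustrative only; the config file name below is an
# example and may not exist in this repository). init_dist('pytorch') in main()
# and the required --local_rank argument assume the script is started through
# the PyTorch distributed launcher, for example:
#
#   python -m torch.distributed.launch --nproc_per_node=8 main_group_vit.py \
#       --cfg configs/group_vit_gcc_yfcc_30e.yml --output ./output --tag default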
def train(cfg, args):
    if cfg.wandb and dist.get_rank() == 0:
        import wandb
        wandb.init(
            project='group_vit',
            name=osp.join(cfg.model_name, cfg.tag),
            dir=cfg.output,
            config=OmegaConf.to_container(cfg, resolve=True),
            resume=cfg.checkpoint.auto_resume)
    else:
        wandb = None
    # wait for wandb to finish initializing on rank 0
    dist.barrier()

    dataset_train, dataset_val, \
        data_loader_train, data_loader_val = build_loader(cfg.data)
    data_loader_seg = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))

    print("\n\n\n")
    print(cfg)
    print("\n\n\n")

    # build the image-text pair dataloaders
    train_loader, val_img_loader, val_txt_loader, num_classes = build_dataloader(args)
    logger = get_logger()
    logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
    model = build_model(cfg.model)
    # load_checkpoint(cfg, model, None, None)

    # Freeze all parameters first.
    for param in model.parameters():
        param.requires_grad = False
    # Then unfreeze only the layers that should be fine-tuned.
    # For example, unfreeze all img_projector parameters:
    for param in model.img_projector.parameters():
        param.requires_grad = True
    # And unfreeze all text_projector parameters:
    for param in model.text_projector.parameters():
        param.requires_grad = True
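    # Optional sanity check: log which parameters remain trainable after freezing
    # the backbone and unfreezing the projector heads, so an unintended freeze
    # pattern shows up in the log early.
    trainable_names = [name for name, param in model.named_parameters() if param.requires_grad]
    logger.info(f'trainable parameters: {trainable_names}')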
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(cfg.train, model)
    if cfg.train.amp_opt_level != 'O0':
        model, optimizer = amp.initialize(model, optimizer, opt_level=cfg.train.amp_opt_level)
    model = MMDistributedDataParallel(model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f'number of params: {n_parameters}')

    lr_scheduler = build_scheduler(cfg.train, optimizer, len(data_loader_train))

    if cfg.checkpoint.auto_resume:
        resume_file = auto_resume_helper(cfg.output)
        if resume_file:
            if cfg.checkpoint.resume:
                logger.warning(f'auto-resume changing resume file from {cfg.checkpoint.resume} to {resume_file}')
            with read_write(cfg):
                cfg.checkpoint.resume = resume_file
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {cfg.output}, ignoring auto resume')
    max_accuracy = max_miou = max_rank1 = 0.0
    max_metrics = {'max_accuracy': max_accuracy, 'max_miou': max_miou, 'max_rank1': max_rank1}

    if cfg.checkpoint.resume:
        max_metrics = load_checkpoint(cfg, model_without_ddp, optimizer, lr_scheduler)
        max_accuracy, max_miou = max_metrics['max_accuracy'], max_metrics['max_miou']
        if 'cls' in cfg.evaluate.task:
            acc1, acc5, loss = validate_cls(cfg, data_loader_val, model)
            logger.info(f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%')
        if 'seg' in cfg.evaluate.task:
            miou = validate_seg(cfg, data_loader_seg, model)
            logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
        if 'reid' in cfg.evaluate.task:
            mrank1 = validate_reid(cfg, data_loader_reid, model)
            logger.info(f'Rank1 of the network on the {len(data_loader_reid)} test images: {mrank1:.2f}%')
        if cfg.evaluate.eval_only:
            return
    logger.info('Start training')
    start_time = time.time()
    for epoch in range(cfg.train.start_epoch, cfg.train.epochs):
        loss_train_dict = train_one_epoch(cfg, model, data_loader_train, optimizer, epoch, lr_scheduler)
        if dist.get_rank() == 0 and (epoch % cfg.checkpoint.save_freq == 0 or epoch == (cfg.train.epochs - 1)):
            save_checkpoint(cfg, epoch, model_without_ddp, {
                'max_accuracy': max_accuracy,
                'max_miou': max_miou,
                'max_rank1': max_rank1
            }, optimizer, lr_scheduler)
        dist.barrier()
        loss_train = loss_train_dict['total_loss']
        logger.info(f'Avg loss of the network on the {len(dataset_train)} train images: {loss_train:.2f}')

        # evaluate
        if (epoch % cfg.evaluate.eval_freq == 0 or epoch == (cfg.train.epochs - 1)):
            if 'cls' in cfg.evaluate.task:
                acc1, acc5, loss = validate_cls(cfg, data_loader_val, model)
                logger.info(f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%')
                max_metrics['max_accuracy'] = max(max_metrics['max_accuracy'], acc1)
                if cfg.evaluate.cls.save_best and dist.get_rank() == 0 and acc1 > max_accuracy:
                    save_checkpoint(
                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_acc1')
                dist.barrier()
                max_accuracy = max_metrics['max_accuracy']
                logger.info(f'Max accuracy: {max_accuracy:.2f}%')
            if 'seg' in cfg.evaluate.task:
                miou = validate_seg(cfg, data_loader_seg, model)
                logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
                max_metrics['max_miou'] = max(max_metrics['max_miou'], miou)
                if cfg.evaluate.seg.save_best and dist.get_rank() == 0 and miou > max_miou:
                    save_checkpoint(
                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_miou')
                dist.barrier()
                max_miou = max_metrics['max_miou']
                logger.info(f'Max mIoU: {max_miou:.2f}%')
            if 'reid' in cfg.evaluate.task:
                mrank1 = validate_reid(cfg, data_loader_reid, model)
                logger.info(f'mRank1 of the network on the {len(data_loader_reid)} test images: {mrank1:.2f}%')
                max_metrics['max_rank1'] = max(max_metrics['max_rank1'], mrank1)
                if cfg.evaluate.reid.save_best and dist.get_rank() == 0 and mrank1 > max_rank1:
                    save_checkpoint(
                        cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_rank1')
                dist.barrier()
                max_rank1 = max_metrics['max_rank1']
                logger.info(f'Max mRank1: {max_rank1:.2f}%')

        if wandb is not None:
            log_stat = {f'epoch/train_{k}': v for k, v in loss_train_dict.items()}
            log_stat.update({
                'epoch/val_acc1': acc1,
                'epoch/val_acc5': acc5,
                'epoch/val_loss': loss,
                'epoch/val_miou': miou,
                'epoch/val_rank1': mrank1,
                'epoch/epoch': epoch,
                'epoch/n_parameters': n_parameters
            })
            wandb.log(log_stat)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    dist.barrier()
def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler):
    logger = get_logger()
    dist.barrier()
    model.train()
    optimizer.zero_grad()
    if config.wandb and dist.get_rank() == 0:
        import wandb
    else:
        wandb = None

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    log_vars_meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    for idx, samples in enumerate(data_loader):
        batch_size = config.data.batch_size
        losses = model(**samples)
        loss, log_vars = parse_losses(losses)
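        # With gradient accumulation the loss of each micro-batch is divided by
        # accumulation_steps and the optimizer only steps every accumulation_steps
        # iterations, so the effective batch size per update is roughly
        # data.batch_size * world_size * accumulation_steps (illustrative numbers:
        # 128 per GPU x 8 GPUs x 2 steps = 2048 samples per optimizer step).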
        if config.train.accumulation_steps > 1:
            loss = loss / config.train.accumulation_steps
            if config.train.amp_opt_level != 'O0':
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.train.accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            optimizer.zero_grad()
            if config.train.amp_opt_level != 'O0':
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.train.clip_grad:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)
        torch.cuda.synchronize()

        loss_meter.update(loss.item(), batch_size)
        for loss_name in log_vars:
            log_vars_meters[loss_name].update(log_vars[loss_name], batch_size)
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.print_freq == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            log_vars_str = '\t'.join(f'{n} {m.val:.4f} ({m.avg:.4f})' for n, m in log_vars_meters.items())
            logger.info(f'Train: [{epoch}/{config.train.epochs}][{idx}/{num_steps}]\t'
                        f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                        f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                        f'total_loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'{log_vars_str}\t'
                        f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                        f'mem {memory_used:.0f}MB')
            if wandb is not None:
                log_stat = {f'iter/train_{n}': m.avg for n, m in log_vars_meters.items()}
                log_stat['iter/train_total_loss'] = loss_meter.avg
                log_stat['iter/learning_rate'] = lr
                wandb.log(log_stat)

    epoch_time = time.time() - start
    logger.info(f'EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}')

    result_dict = dict(total_loss=loss_meter.avg)
    for n, m in log_vars_meters.items():
        result_dict[n] = m.avg
    dist.barrier()
    return result_dict
@torch.no_grad()
def validate_cls(config, data_loader, model):
    logger = get_logger()
    dist.barrier()
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)

    end = time.time()
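    # Zero-shot classification: each class name is inserted into the prompt
    # templates, tokenized with the text transform, and encoded into per-class
    # text embeddings; these embeddings act as the classifier weights that score
    # every image in the loop below.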
    logger.info('Building zero shot classifier')
    text_embedding = data2cuda(
        model.module.build_text_embedding(
            build_dataset_class_tokens(text_transform, config.evaluate.cls.template, imagenet_classes)))
    logger.info('Zero shot classifier built')
    for idx, samples in enumerate(data_loader):
        target = samples.pop('target').data[0].cuda()
        target = data2cuda(target)

        # compute output
        output = model(**samples, text=text_embedding)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.print_freq == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                        f'Mem {memory_used:.0f}MB')
    logger.info('Clearing zero shot classifier')
    torch.cuda.empty_cache()
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    dist.barrier()
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
@torch.no_grad()
def validate_seg(config, data_loader, model):
    logger = get_logger()
    dist.barrier()
    model.eval()

    if hasattr(model, 'module'):
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
    seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)

    mmddp_model = MMDistributedDataParallel(
        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    mmddp_model.eval()
    results = multi_gpu_test(
        model=mmddp_model,
        data_loader=data_loader,
        tmpdir=None,
        gpu_collect=True,
        efficient_test=False,
        pre_eval=True,
        format_only=False)
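    # multi_gpu_test gathers the per-image results across ranks (gpu_collect=True);
    # only rank 0 computes the mIoU, and the metric is then broadcast so that
    # every process returns the same value.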
    if dist.get_rank() == 0:
        metric = [data_loader.dataset.evaluate(results, metric='mIoU')]
    else:
        metric = [None]
    dist.broadcast_object_list(metric)
    miou_result = metric[0]['mIoU'] * 100

    torch.cuda.empty_cache()
    logger.info(f'Eval Seg mIoU {miou_result:.2f}')
    dist.barrier()
    return miou_result
@torch.no_grad()
def validate_reid(config, data_loader, model):
    # TODO: re-identification evaluation is not implemented yet. Return 0.0 so
    # the rank-1 logging and best-checkpoint logic in train() does not fail on a
    # None value.
    return 0.0
def main():
    args = parse_args()
    cfg = get_config(args)

    if cfg.train.amp_opt_level != 'O0':
        assert amp is not None, 'amp not installed!'

    # start faster ref: https://github.com/open-mmlab/mmdetection/pull/7036
    mp.set_start_method('fork', force=True)
    init_dist('pytorch')
    rank, world_size = get_dist_info()
    print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')

    dist.barrier()

    set_random_seed(cfg.seed, use_rank_shift=True)
    cudnn.benchmark = True

    os.makedirs(cfg.output, exist_ok=True)
    logger = get_logger(cfg)
    # linearly scale the learning rate with the total batch size; may not be optimal
    linear_scaled_lr = cfg.train.base_lr * cfg.data.batch_size * world_size / 4096.0
    linear_scaled_warmup_lr = cfg.train.warmup_lr * cfg.data.batch_size * world_size / 4096.0
    linear_scaled_min_lr = cfg.train.min_lr * cfg.data.batch_size * world_size / 4096.0
    # gradient accumulation also needs to scale the learning rate
    if cfg.train.accumulation_steps > 1:
        linear_scaled_lr = linear_scaled_lr * cfg.train.accumulation_steps
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * cfg.train.accumulation_steps
        linear_scaled_min_lr = linear_scaled_min_lr * cfg.train.accumulation_steps
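    # Worked example with illustrative numbers (not taken from any config in this
    # repository): base_lr = 1.6e-3, batch_size = 128 per GPU, world_size = 8 and
    # accumulation_steps = 1 give linear_scaled_lr = 1.6e-3 * 128 * 8 / 4096 = 4e-4.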
    with read_write(cfg):
        logger.info(f'Scale base_lr from {cfg.train.base_lr} to {linear_scaled_lr}')
        logger.info(f'Scale warmup_lr from {cfg.train.warmup_lr} to {linear_scaled_warmup_lr}')
        logger.info(f'Scale min_lr from {cfg.train.min_lr} to {linear_scaled_min_lr}')
        cfg.train.base_lr = linear_scaled_lr
        cfg.train.warmup_lr = linear_scaled_warmup_lr
        cfg.train.min_lr = linear_scaled_min_lr

    if dist.get_rank() == 0:
        path = os.path.join(cfg.output, 'config.json')
        OmegaConf.save(cfg, path)
        logger.info(f'Full config saved to {path}')

    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)

    logger.info(f'Git hash: {get_git_hash(digits=7)}')

    # print config
    logger.info(OmegaConf.to_yaml(cfg))

    train(cfg, args)
    dist.barrier()


if __name__ == '__main__':
    main()