
# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
import os
import os.path as osp

from omegaconf import OmegaConf

def load_config(cfg_file):
    """Load a config file and, if it declares `_base_`, merge it on top of its base config(s)."""
    cfg = OmegaConf.load(cfg_file)
    if '_base_' in cfg:
        if isinstance(cfg._base_, str):
            # single base config, resolved relative to the current config file
            base_cfg = OmegaConf.load(osp.join(osp.dirname(cfg_file), cfg._base_))
        else:
            # several base configs, merged left to right
            base_cfg = OmegaConf.merge(*[OmegaConf.load(f) for f in cfg._base_])
        # values in the child config override those in the base config
        cfg = OmegaConf.merge(base_cfg, cfg)
    return cfg

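# Example of the `_base_` inheritance that load_config() handles (an assumed
# layout for illustration, not a file shipped with this module): a child
# config such as
#
#     _base_: default.yml
#     data:
#       train:
#         batch_size: 64
#
# first loads `default.yml` (resolved next to the child file), then merges the
# child on top, so the child's values win.
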
def get_config(args):
    """Build the final config from the config file plus command-line overrides."""
    cfg = load_config(args.cfg)
    OmegaConf.set_struct(cfg, True)

    # dot-list overrides, e.g. `train.epochs=30`
    if args.opts is not None:
        cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(args.opts))

    # merge from specific command-line arguments
    if hasattr(args, 'batch_size') and args.batch_size:
        cfg.data.train.batch_size = args.batch_size
    if hasattr(args, 'amp_opt_level') and args.amp_opt_level:
        cfg.train.amp_opt_level = args.amp_opt_level
    if hasattr(args, 'resume') and args.resume:
        cfg.checkpoint.resume = args.resume
    if hasattr(args, 'eval') and args.eval:
        cfg.evaluate.eval_only = args.eval
    if hasattr(args, 'keep') and args.keep:
        cfg.checkpoint.max_kept = args.keep

    # default model name: config file name, suffixed with per-GPU batch size and world size
    if not cfg.model_name:
        cfg.model_name = osp.splitext(osp.basename(args.cfg))[0]
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    cfg.model_name = cfg.model_name + f'_bs{cfg.data.train.batch_size}x{world_size}'
    # output folder
    # if hasattr(args, 'output') and args.output:
    #     cfg.output = args.output
    if hasattr(args, 'output'):
        cfg.output = osp.join(cfg.output, cfg.model_name)
    else:
        cfg.output = osp.join('output', cfg.model_name)

    if hasattr(args, 'tag') and args.tag:
        cfg.tag = args.tag
        cfg.output = osp.join(cfg.output, cfg.tag)

    if hasattr(args, 'wandb') and args.wandb:
        cfg.wandb = args.wandb

    if hasattr(args, 'vis') and args.vis:
        cfg.vis = args.vis
    ### for demo only ###
    if hasattr(args, 'vocab') and args.vocab:
        cfg.vocab = args.vocab
    if hasattr(args, 'image_folder') and args.image_folder:
        cfg.image_folder = args.image_folder
    if hasattr(args, 'output_folder') and args.output_folder:
        cfg.output_folder = args.output_folder

    cfg.local_rank = args.local_rank

    OmegaConf.set_readonly(cfg, True)
    return cfg
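
# The block below is a minimal usage sketch, not part of the original module:
# it wires up an argparse namespace with the attributes that get_config() reads
# above (cfg, opts, batch_size, resume, eval, tag, local_rank). The project's
# actual training/demo scripts may expose different flag names and defaults;
# treat these as illustrative assumptions.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Resolve and print a config')
    parser.add_argument('--cfg', required=True, help='path to the YAML config file')
    parser.add_argument('--opts', nargs='*', default=None,
                        help='dot-list overrides, e.g. train.epochs=30')
    parser.add_argument('--batch-size', dest='batch_size', type=int, default=None)
    parser.add_argument('--resume', default=None, help='checkpoint to resume from')
    parser.add_argument('--eval', action='store_true', help='evaluation-only mode')
    parser.add_argument('--tag', default=None, help='experiment tag appended to the output dir')
    parser.add_argument('--local_rank', type=int, default=0)
    demo_args = parser.parse_args()

    # print the fully resolved (read-only) config as YAML
    print(OmegaConf.to_yaml(get_config(demo_args)))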