config.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
  3. #
  4. # This work is made available under the Nvidia Source Code License.
  5. # To view a copy of this license, visit
  6. # https://github.com/NVlabs/GroupViT/blob/main/LICENSE
  7. #
  8. # Written by Jiarui Xu
  9. # -------------------------------------------------------------------------
  10. import os
  11. import os.path as osp
  12. from omegaconf import OmegaConf
  13. def load_config(cfg_file):
  14. cfg = OmegaConf.load(cfg_file)
  15. if '_base_' in cfg:
  16. if isinstance(cfg._base_, str):
  17. base_cfg = OmegaConf.load(osp.join(osp.dirname(cfg_file), cfg._base_))
  18. else:
  19. base_cfg = OmegaConf.merge(OmegaConf.load(f) for f in cfg._base_)
  20. cfg = OmegaConf.merge(base_cfg, cfg)
  21. return cfg
  22. def get_config(args):
  23. cfg = load_config(args.cfg)
  24. OmegaConf.set_struct(cfg, True)
  25. if args.opts is not None:
  26. cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(args.opts))
  27. if hasattr(args, 'batch_size') and args.batch_size:
  28. cfg.data.batch_size = args.batch_size
  29. if hasattr(args, 'amp_opt_level') and args.amp_opt_level:
  30. cfg.train.amp_opt_level = args.amp_opt_level
  31. if hasattr(args, 'resume') and args.resume:
  32. cfg.checkpoint.resume = args.resume
  33. if hasattr(args, 'eval') and args.eval:
  34. cfg.evaluate.eval_only = args.eval
  35. if hasattr(args, 'keep') and args.keep:
  36. cfg.checkpoint.max_kept = args.keep
  37. if not cfg.model_name:
  38. cfg.model_name = osp.splitext(osp.basename(args.cfg))[0]
  39. world_size = int(os.environ.get('WORLD_SIZE', 1))
  40. cfg.model_name = cfg.model_name + f'_bs{cfg.data.batch_size}x{world_size}'
  41. if hasattr(args, 'output') and args.output:
  42. cfg.output = args.output
  43. else:
  44. cfg.output = osp.join('output', cfg.model_name)
  45. if hasattr(args, 'tag') and args.tag:
  46. cfg.tag = args.tag
  47. cfg.output = osp.join(cfg.output, cfg.tag)
  48. if hasattr(args, 'wandb') and args.wandb:
  49. cfg.wandb = args.wandb
  50. if hasattr(args, 'vis') and args.vis:
  51. cfg.vis = args.vis
  52. cfg.local_rank = args.local_rank
  53. OmegaConf.set_readonly(cfg, True)
  54. return cfg