1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- # -------------------------------------------------------------------------
- # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- #
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
- # property and proprietary rights in and to this software, related
- # documentation and any modifications thereto. Any use, reproduction,
- # disclosure or distribution of this software and related documentation
- # without an express license agreement from NVIDIA CORPORATION is strictly
- # prohibited.
- #
- # Written by Jiarui Xu
- # -------------------------------------------------------------------------
- import os
- import os.path as osp
- from omegaconf import OmegaConf
def load_config(cfg_file):
    """Load a YAML config file and merge it on top of its ``_base_`` config(s).

    If the loaded config has a ``_base_`` key, it names either one base file
    (a string) or several base files (a list). Every base path is resolved
    relative to the directory that contains ``cfg_file``; an absolute path is
    used as-is. Several bases are merged in list order, so later entries
    override earlier ones. Values from ``cfg_file`` itself override all bases.
    An empty or null ``_base_`` means "no base".

    Only one level of ``_base_`` is resolved: base files are loaded plainly,
    without looking at their own ``_base_`` keys.

    Args:
        cfg_file: Path to the YAML config file to load.

    Returns:
        The merged OmegaConf config. The ``_base_`` key itself is kept.
    """
    cfg = OmegaConf.load(cfg_file)
    if '_base_' in cfg:
        base = cfg._base_
        if base is None:
            base_files = []
        elif isinstance(base, str):
            base_files = [base]
        else:
            base_files = list(base)

        if base_files:
            cfg_dir = osp.dirname(cfg_file)
            # OmegaConf.merge takes its configs as *varargs. Handing it a bare
            # generator makes it try to copy the generator, which fails, so
            # the loaded bases are unpacked explicitly.
            base_cfg = OmegaConf.merge(
                *[OmegaConf.load(osp.join(cfg_dir, f)) for f in base_files])
            cfg = OmegaConf.merge(base_cfg, cfg)
    return cfg
def _apply_arg_overrides(cfg, args, overrides):
    """Copy ``args.<attr>`` into ``cfg`` at a dotted key, for each truthy attr.

    Args:
        cfg: The config to modify in place.
        args: Namespace to read from. Missing attributes are skipped.
        overrides: Iterable of ``(attr, dotted_key)`` pairs, applied in order.
    """
    for attr, dotted_key in overrides:
        value = getattr(args, attr, None)
        if not value:
            # Same rule as ``hasattr(args, attr) and getattr(args, attr)``:
            # absent, None, False, 0 and '' all mean "no override".
            continue
        *parents, leaf = dotted_key.split('.')
        node = cfg
        for name in parents:
            node = getattr(node, name)
        setattr(node, leaf, value)


def get_config(args):
    """Build the read-only run config from a YAML file plus CLI arguments.

    Steps, in order:
      1. Load ``args.cfg`` with :func:`load_config` and enable struct mode.
      2. Merge ``args.opts`` (a list of ``key=value`` strings) if it is not None.
      3. Apply individual CLI overrides. An override is applied only when
         ``args`` has the attribute and its value is truthy.
      4. Derive ``model_name`` (default: the config file's stem) and suffix it
         with ``_bs{batch_size}x{WORLD_SIZE}``.
      5. Derive ``output`` and optionally append ``args.tag``.
      6. Record ``local_rank`` and mark the config read-only.

    Args:
        args: An argparse-style namespace. ``cfg`` and ``opts`` are required.
            Every other attribute is optional.

    Returns:
        The merged config, marked read-only.
    """
    cfg = load_config(args.cfg)
    # Struct mode: merging or assigning a key that is absent from the loaded
    # config raises instead of silently adding it (catches typos in opts).
    OmegaConf.set_struct(cfg, True)

    if args.opts is not None:
        cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(args.opts))

    # These must run before model_name is built, because the (possibly
    # overridden) training batch size is baked into it below.
    _apply_arg_overrides(cfg, args, (
        ('batch_size', 'data.train.batch_size'),
        ('amp_opt_level', 'train.amp_opt_level'),
        ('resume', 'checkpoint.resume'),
        ('eval', 'evaluate.eval_only'),
        ('keep', 'checkpoint.max_kept'),
    ))

    if not cfg.model_name:
        cfg.model_name = osp.splitext(osp.basename(args.cfg))[0]
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    cfg.model_name = cfg.model_name + f'_bs{cfg.data.train.batch_size}x{world_size}'

    # NOTE(review): only the *presence* of args.output matters here; its value
    # is never copied into cfg.output. Confirm that a user-supplied output
    # directory is meant to be ignored.
    if hasattr(args, 'output'):
        cfg.output = osp.join(cfg.output, cfg.model_name)
    else:
        cfg.output = osp.join('output', cfg.model_name)

    if hasattr(args, 'tag') and args.tag:
        cfg.tag = args.tag
        cfg.output = osp.join(cfg.output, cfg.tag)

    _apply_arg_overrides(cfg, args, (
        ('wandb', 'wandb'),
        ('vis', 'vis'),
        # The following three are only used by the demo.
        ('vocab', 'vocab'),
        ('image_folder', 'image_folder'),
        ('output_folder', 'output_folder'),
    ))

    # Every other CLI field is optional, so local_rank is too. Without the
    # attribute, fall back to the LOCAL_RANK environment variable (read the
    # same way as WORLD_SIZE above) instead of raising AttributeError.
    if hasattr(args, 'local_rank'):
        cfg.local_rank = args.local_rank
    else:
        cfg.local_rank = int(os.environ.get('LOCAL_RANK', 0))

    OmegaConf.set_readonly(cfg, True)
    return cfg
|