# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------

import os
import os.path as osp

from omegaconf import OmegaConf


def load_config(cfg_file):
    """Load a config file and merge it on top of its ``_base_`` config(s).

    If the loaded config has a ``_base_`` key, it names either a single base
    file (a string) or a sequence of base files. Base paths are resolved
    relative to the directory of ``cfg_file``. Bases are merged in the order
    listed and ``cfg_file`` is merged last, so its values take precedence.

    Base files are loaded as-is: a ``_base_`` key inside a base file is not
    followed.

    Args:
        cfg_file: Path to the config file to load.

    Returns:
        The merged OmegaConf config, or the config loaded from ``cfg_file``
        unchanged when it has no ``_base_`` key.
    """
    cfg = OmegaConf.load(cfg_file)
    if '_base_' not in cfg:
        return cfg

    if isinstance(cfg._base_, str):
        base_files = [cfg._base_]
    else:
        base_files = list(cfg._base_)

    # Resolve every base against the config's own directory so the string and
    # list forms of ``_base_`` behave the same regardless of the working dir.
    base_dir = osp.dirname(cfg_file)
    base_cfgs = [OmegaConf.load(osp.join(base_dir, f)) for f in base_files]

    # ``OmegaConf.merge`` takes the configs as separate positional arguments,
    # later ones overriding earlier ones.
    return OmegaConf.merge(*base_cfgs, cfg)


def _arg(args, name):
    """Return ``args.<name>``, or ``None`` when ``args`` has no such attribute."""
    return getattr(args, name, None)


def get_config(args):
    """Build the final read-only config from parsed command-line arguments.

    Loads ``args.cfg`` via ``load_config``, applies the ``args.opts`` dotlist
    overrides, then copies selected command-line values into the config and
    derives ``model_name`` and ``output``.

    Args:
        args: Argument namespace. ``cfg`` and ``local_rank`` are required.
            The following attributes are optional and only applied when
            present and truthy: ``batch_size``, ``amp_opt_level``,
            ``resume``, ``eval``, ``keep``, ``tag``, ``wandb``, ``vis``,
            ``vocab``, ``image_folder``, ``output_folder``. ``opts`` is
            applied when present and not ``None``. For ``output`` only the
            presence of the attribute is checked.

    Returns:
        The config, with the read-only flag set.
    """
    cfg = load_config(args.cfg)
    # Struct mode is enabled before any override is applied; per the OmegaConf
    # docs this makes access to keys missing from the loaded config an error.
    OmegaConf.set_struct(cfg, True)

    opts = _arg(args, 'opts')
    if opts is not None:
        cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(opts))

    if _arg(args, 'batch_size'):
        cfg.data.train.batch_size = args.batch_size

    if _arg(args, 'amp_opt_level'):
        cfg.train.amp_opt_level = args.amp_opt_level

    if _arg(args, 'resume'):
        cfg.checkpoint.resume = args.resume

    if _arg(args, 'eval'):
        cfg.evaluate.eval_only = args.eval

    if _arg(args, 'keep'):
        cfg.checkpoint.max_kept = args.keep

    # Fall back to the config file's base name (without extension).
    if not cfg.model_name:
        cfg.model_name = osp.splitext(osp.basename(args.cfg))[0]

    # Suffix the name with "<batch size>x<world size>"; WORLD_SIZE defaults
    # to 1 when the environment variable is not set.
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    cfg.model_name = cfg.model_name + f'_bs{cfg.data.train.batch_size}x{world_size}'

    # NOTE(review): the value of ``args.output`` is never used. Its mere
    # presence selects ``cfg.output`` as the root directory; otherwise the
    # literal 'output' is used. Confirm whether ``args.output`` was meant to
    # override ``cfg.output``.
    if hasattr(args, 'output'):
        cfg.output = osp.join(cfg.output, cfg.model_name)
    else:
        cfg.output = osp.join('output', cfg.model_name)

    if _arg(args, 'tag'):
        cfg.tag = args.tag
        cfg.output = osp.join(cfg.output, cfg.tag)

    if _arg(args, 'wandb'):
        cfg.wandb = args.wandb

    if _arg(args, 'vis'):
        cfg.vis = args.vis

    ### for demo only ###
    if _arg(args, 'vocab'):
        cfg.vocab = args.vocab

    if _arg(args, 'image_folder'):
        cfg.image_folder = args.image_folder

    if _arg(args, 'output_folder'):
        cfg.output_folder = args.output_folder

    cfg.local_rank = args.local_rank

    OmegaConf.set_readonly(cfg, True)

    return cfg