|
@@ -0,0 +1,62 @@
|
|
|
+import argparse
|
|
|
+
|
|
|
+
|
|
|
+def cfg2arg(cfg):
|
|
|
+ # 定义argparse对象
|
|
|
+ parser = argparse.ArgumentParser()
|
|
|
+
|
|
|
+ # 添加参数
|
|
|
+ parser.add_argument('--local_rank', type=int, default=cfg['local_rank'])
|
|
|
+ parser.add_argument('--name', type=str, default=cfg['model_name'])
|
|
|
+ parser.add_argument('--output_dir', type=str, default=cfg['output'])
|
|
|
+ parser.add_argument('--log_period', type=int, default=cfg['print_freq'])
|
|
|
+ parser.add_argument('--eval_period', type=int, default=cfg['evaluate']['eval_freq'])
|
|
|
+ parser.add_argument('--val_dataset', type=str, default=cfg['data']['dataset']['val'][0])
|
|
|
+ parser.add_argument('--resume', type=bool, default=cfg['checkpoint']['auto_resume'])
|
|
|
+ parser.add_argument('--resume_ckpt_file', type=str, default=cfg['checkpoint']['resume'])
|
|
|
+ parser.add_argument('--pretrain_choice', type=str, default='ViT-B/16') # 这里假设预训练选择是固定的
|
|
|
+ parser.add_argument('--temperature', type=float, default=cfg['model']['contrast_temperature'])
|
|
|
+ parser.add_argument('--img_aug', type=bool, default=cfg['data']['img_aug']['deit_aug'])
|
|
|
+ parser.add_argument('--cmt_depth', type=int, default=4) # 这里假设cmt_depth是固定的
|
|
|
+ parser.add_argument('--masked_token_rate', type=float, default=0.8) # 这里假设masked_token_rate是固定的
|
|
|
+ parser.add_argument('--masked_token_unchanged_rate', type=float, default=0.1) # 这里假设masked_token_unchanged_rate是固定的
|
|
|
+ parser.add_argument('--lr_factor', type=float, default=5.0) # 这里假设lr_factor是固定的
|
|
|
+ parser.add_argument('--MLM', type=bool, default=True) # 这里假设MLM是固定的
|
|
|
+ parser.add_argument('--loss_names', type=str, default='sdm+mlm+id') # 这里假设loss_names是固定的
|
|
|
+ parser.add_argument('--mlm_loss_weight', type=float, default=1.0) # 这里假设mlm_loss_weight是固定的
|
|
|
+ parser.add_argument('--id_loss_weight', type=float, default=1.0) # 这里假设id_loss_weight是固定的
|
|
|
+ parser.add_argument('--img_size', type=tuple, default=(cfg['data']['img_aug']['img_size'], cfg['data']['img_aug']['img_size']))
|
|
|
+ parser.add_argument('--stride_size', type=int, default=16) # 这里假设stride_size是固定的
|
|
|
+ parser.add_argument('--text_length', type=int, default=cfg['data']['text_aug']['max_seq_len'])
|
|
|
+ parser.add_argument('--vocab_size', type=int, default=cfg['model']['text_encoder']['vocab_size'])
|
|
|
+ parser.add_argument('--optimizer', type=str, default=cfg['train']['optimizer']['name'])
|
|
|
+ parser.add_argument('--lr', type=float, default=cfg['train']['base_lr'])
|
|
|
+ parser.add_argument('--bias_lr_factor', type=float, default=2.0) # 这里假设bias_lr_factor是固定的
|
|
|
+ parser.add_argument('--momentum', type=float, default=0.9) # 这里假设momentum是固定的
|
|
|
+ parser.add_argument('--weight_decay', type=float, default=cfg['train']['weight_decay'])
|
|
|
+ parser.add_argument('--weight_decay_bias', type=float, default=0.0) # 这里假设weight_decay_bias是固定的
|
|
|
+ parser.add_argument('--alpha', type=float, default=0.9) # 这里假设alpha是固定的
|
|
|
+ parser.add_argument('--beta', type=float, default=0.999) # 这里假设beta是固定的
|
|
|
+ parser.add_argument('--num_epoch', type=int, default=cfg['train']['epochs'])
|
|
|
+ parser.add_argument('--milestones', type=tuple, default=(20, 50)) # 这里假设milestones是固定的
|
|
|
+ parser.add_argument('--gamma', type=float, default=0.1) # 这里假设gamma是固定的
|
|
|
+ parser.add_argument('--warmup_factor', type=float, default=0.1) # 这里假设warmup_factor是固定的
|
|
|
+ parser.add_argument('--warmup_epochs', type=int, default=cfg['train']['warmup_epochs'])
|
|
|
+ parser.add_argument('--warmup_method', type=str, default='linear') # 这里假设warmup_method是固定的
|
|
|
+ parser.add_argument('--lrscheduler', type=str, default=cfg['train']['lr_scheduler']['name'])
|
|
|
+ parser.add_argument('--target_lr', type=float, default=0) # 这里假设target_lr是固定的
|
|
|
+ parser.add_argument('--power', type=float, default=0.9) # 这里假设power是固定的
|
|
|
+ parser.add_argument('--dataset_name', type=str, default='CUHK-PEDES') # 这里假设dataset_name是固定的
|
|
|
+ parser.add_argument('--sampler', type=str, default='random') # 这里假设sampler是固定的
|
|
|
+ parser.add_argument('--num_instance', type=int, default=4) # 这里假设num_instance是固定的
|
|
|
+ parser.add_argument('--root_dir', type=str, default='/home/linkslinks/dataset') # 这里假设root_dir是固定的
|
|
|
+ parser.add_argument('--batch_size', type=int, default=cfg['data']['batch_size'])
|
|
|
+ parser.add_argument('--test_batch_size', type=int, default=512) # 这里假设test_batch_size是固定的
|
|
|
+ parser.add_argument('--num_workers', type=int, default=cfg['data']['num_workers'])
|
|
|
+ parser.add_argument('--training', type=bool, default=True) # 这里假设training是固定的
|
|
|
+ parser.add_argument('--distributed', type=bool, default=False) # 这里假设distributed是固定的
|
|
|
+
|
|
|
+ # 解析参数
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ return args
|