cfg2arg.py 4.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import argparse
  2. def cfg2arg(cfg):
  3. # 定义argparse对象
  4. parser = argparse.ArgumentParser()
  5. # 添加参数
  6. parser.add_argument('--local_rank', type=int, default=cfg['local_rank'])
  7. parser.add_argument('--name', type=str, default=cfg['model_name'])
  8. parser.add_argument('--output_dir', type=str, default=cfg['output'])
  9. parser.add_argument('--log_period', type=int, default=cfg['print_freq'])
  10. parser.add_argument('--eval_period', type=int, default=cfg['evaluate']['eval_freq'])
  11. parser.add_argument('--val_dataset', type=str, default=cfg['data']['dataset']['val'][0])
  12. parser.add_argument('--resume', type=bool, default=cfg['checkpoint']['auto_resume'])
  13. parser.add_argument('--resume_ckpt_file', type=str, default=cfg['checkpoint']['resume'])
  14. parser.add_argument('--pretrain_choice', type=str, default='ViT-B/16') # 这里假设预训练选择是固定的
  15. parser.add_argument('--temperature', type=float, default=cfg['model']['contrast_temperature'])
  16. parser.add_argument('--img_aug', type=bool, default=cfg['data']['img_aug']['deit_aug'])
  17. parser.add_argument('--cmt_depth', type=int, default=4) # 这里假设cmt_depth是固定的
  18. parser.add_argument('--masked_token_rate', type=float, default=0.8) # 这里假设masked_token_rate是固定的
  19. parser.add_argument('--masked_token_unchanged_rate', type=float, default=0.1) # 这里假设masked_token_unchanged_rate是固定的
  20. parser.add_argument('--lr_factor', type=float, default=5.0) # 这里假设lr_factor是固定的
  21. parser.add_argument('--MLM', type=bool, default=True) # 这里假设MLM是固定的
  22. parser.add_argument('--loss_names', type=str, default='sdm+mlm+id') # 这里假设loss_names是固定的
  23. parser.add_argument('--mlm_loss_weight', type=float, default=1.0) # 这里假设mlm_loss_weight是固定的
  24. parser.add_argument('--id_loss_weight', type=float, default=1.0) # 这里假设id_loss_weight是固定的
  25. parser.add_argument('--img_size', type=tuple, default=(cfg['data']['img_aug']['img_size'], cfg['data']['img_aug']['img_size']))
  26. parser.add_argument('--stride_size', type=int, default=16) # 这里假设stride_size是固定的
  27. parser.add_argument('--text_length', type=int, default=cfg['data']['text_aug']['max_seq_len'])
  28. parser.add_argument('--vocab_size', type=int, default=cfg['model']['text_encoder']['vocab_size'])
  29. parser.add_argument('--optimizer', type=str, default=cfg['train']['optimizer']['name'])
  30. parser.add_argument('--lr', type=float, default=cfg['train']['base_lr'])
  31. parser.add_argument('--bias_lr_factor', type=float, default=2.0) # 这里假设bias_lr_factor是固定的
  32. parser.add_argument('--momentum', type=float, default=0.9) # 这里假设momentum是固定的
  33. parser.add_argument('--weight_decay', type=float, default=cfg['train']['weight_decay'])
  34. parser.add_argument('--weight_decay_bias', type=float, default=0.0) # 这里假设weight_decay_bias是固定的
  35. parser.add_argument('--alpha', type=float, default=0.9) # 这里假设alpha是固定的
  36. parser.add_argument('--beta', type=float, default=0.999) # 这里假设beta是固定的
  37. parser.add_argument('--num_epoch', type=int, default=cfg['train']['epochs'])
  38. parser.add_argument('--milestones', type=tuple, default=(20, 50)) # 这里假设milestones是固定的
  39. parser.add_argument('--gamma', type=float, default=0.1) # 这里假设gamma是固定的
  40. parser.add_argument('--warmup_factor', type=float, default=0.1) # 这里假设warmup_factor是固定的
  41. parser.add_argument('--warmup_epochs', type=int, default=cfg['train']['warmup_epochs'])
  42. parser.add_argument('--warmup_method', type=str, default='linear') # 这里假设warmup_method是固定的
  43. parser.add_argument('--lrscheduler', type=str, default=cfg['train']['lr_scheduler']['name'])
  44. parser.add_argument('--target_lr', type=float, default=0) # 这里假设target_lr是固定的
  45. parser.add_argument('--power', type=float, default=0.9) # 这里假设power是固定的
  46. parser.add_argument('--dataset_name', type=str, default='CUHK-PEDES') # 这里假设dataset_name是固定的
  47. parser.add_argument('--sampler', type=str, default='random') # 这里假设sampler是固定的
  48. parser.add_argument('--num_instance', type=int, default=4) # 这里假设num_instance是固定的
  49. parser.add_argument('--root_dir', type=str, default='/home/linkslinks/dataset') # 这里假设root_dir是固定的
  50. parser.add_argument('--batch_size', type=int, default=cfg['data']['batch_size'])
  51. parser.add_argument('--test_batch_size', type=int, default=512) # 这里假设test_batch_size是固定的
  52. parser.add_argument('--num_workers', type=int, default=cfg['data']['num_workers'])
  53. parser.add_argument('--training', type=bool, default=True) # 这里假设training是固定的
  54. parser.add_argument('--distributed', type=bool, default=False) # 这里假设distributed是固定的
  55. # 解析参数
  56. args = parser.parse_args()
  57. return args