options.py 4.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import argparse
  2. def get_args():
  3. parser = argparse.ArgumentParser(description="IRRA Args")
  4. ######################## general settings ########################
  5. parser.add_argument("--local_rank", default=0, type=int)
  6. parser.add_argument("--name", default="baseline", help="experiment name to save")
  7. parser.add_argument("--output_dir", default="logs")
  8. parser.add_argument("--log_period", default=100)
  9. parser.add_argument("--eval_period", default=1)
  10. parser.add_argument("--val_dataset", default="test") # use val set when evaluate, if test use test set
  11. parser.add_argument("--resume", default=False, action='store_true')
  12. parser.add_argument("--resume_ckpt_file", default="", help='resume from ...')
  13. ######################## model general settings ########################
  14. parser.add_argument("--pretrain_choice", default='ViT-B/16') # whether use pretrained model
  15. parser.add_argument("--temperature", type=float, default=0.02, help="initial temperature value, if 0, don't use temperature")
  16. parser.add_argument("--img_aug", default=False, action='store_true')
  17. ## cross modal transfomer setting
  18. parser.add_argument("--cmt_depth", type=int, default=4, help="cross modal transformer self attn layers")
  19. parser.add_argument("--masked_token_rate", type=float, default=0.8, help="masked token rate for mlm task")
  20. parser.add_argument("--masked_token_unchanged_rate", type=float, default=0.1, help="masked token unchanged rate")
  21. parser.add_argument("--lr_factor", type=float, default=5.0, help="lr factor for random init self implement module")
  22. parser.add_argument("--MLM", default=False, action='store_true', help="whether to use Mask Language Modeling dataset")
  23. ######################## loss settings ########################
  24. parser.add_argument("--loss_names", default='sdm+id+mlm', help="which loss to use ['mlm', 'cmpm', 'id', 'itc', 'sdm']")
  25. parser.add_argument("--mlm_loss_weight", type=float, default=1.0, help="mlm loss weight")
  26. parser.add_argument("--id_loss_weight", type=float, default=1.0, help="id loss weight")
  27. ######################## vison trainsformer settings ########################
  28. parser.add_argument("--img_size", type=tuple, default=(384, 128))
  29. parser.add_argument("--stride_size", type=int, default=16)
  30. ######################## text transformer settings ########################
  31. parser.add_argument("--text_length", type=int, default=77)
  32. parser.add_argument("--vocab_size", type=int, default=49408)
  33. ######################## solver ########################
  34. parser.add_argument("--optimizer", type=str, default="Adam", help="[SGD, Adam, Adamw]")
  35. parser.add_argument("--lr", type=float, default=1e-5)
  36. parser.add_argument("--bias_lr_factor", type=float, default=2.)
  37. parser.add_argument("--momentum", type=float, default=0.9)
  38. parser.add_argument("--weight_decay", type=float, default=4e-5)
  39. parser.add_argument("--weight_decay_bias", type=float, default=0.)
  40. parser.add_argument("--alpha", type=float, default=0.9)
  41. parser.add_argument("--beta", type=float, default=0.999)
  42. ######################## scheduler ########################
  43. parser.add_argument("--num_epoch", type=int, default=60)
  44. parser.add_argument("--milestones", type=int, nargs='+', default=(20, 50))
  45. parser.add_argument("--gamma", type=float, default=0.1)
  46. parser.add_argument("--warmup_factor", type=float, default=0.1)
  47. parser.add_argument("--warmup_epochs", type=int, default=5)
  48. parser.add_argument("--warmup_method", type=str, default="linear")
  49. parser.add_argument("--lrscheduler", type=str, default="cosine")
  50. parser.add_argument("--target_lr", type=float, default=0)
  51. parser.add_argument("--power", type=float, default=0.9)
  52. ######################## dataset ########################
  53. parser.add_argument("--dataset_name", default="CUHK-PEDES", help="[CUHK-PEDES, ICFG-PEDES, RSTPReid]")
  54. parser.add_argument("--sampler", default="random", help="choose sampler from [idtentity, random]")
  55. parser.add_argument("--num_instance", type=int, default=4)
  56. parser.add_argument("--root_dir", default="./data")
  57. parser.add_argument("--batch_size", type=int, default=128)
  58. parser.add_argument("--test_batch_size", type=int, default=512)
  59. parser.add_argument("--num_workers", type=int, default=8)
  60. parser.add_argument("--test", dest='training', default=True, action='store_false')
  61. args = parser.parse_args()
  62. return args