Просмотр исходного кода

data(cuhkpedes): 更新数据集配置和加载逻辑

- 在 default.yml 中添加 bpe_path 配置项
- 修改 TextDataset 初始化时使用指定的 bpe_path
- 更新 build_dataloader 函数以使用验证集而不是测试集
- 调整 main_group_vit.py 中的数据加载逻辑
Yijun Fu 1 месяц назад
Родитель
Сommit
e2a6f804a8
5 измененных файлов с 94 добавлено и 26 удалено
  1. 6 2
      configs/default.yml
  2. 1 1
      datasets/bases.py
  3. 20 22
      datasets/build.py
  4. 5 1
      main_group_vit.py
  5. 62 0
      tools/cfg2arg.py

+ 6 - 2
configs/default.yml

@@ -6,6 +6,7 @@ data:
   # the differences become academic.
   shuffle_buffer: 10000
   seed: ${train.seed}
+  bpe_path: /home/linkslinks/文档/ai/GroupViT/datasets/bpe_simple_vocab_16e6.txt.gz
   dataset:
     meta:
       gcc3m:
@@ -34,14 +35,17 @@ data:
         prefix: imagenet-val-{000000..000049}.tar
         length: 50000
       cuhkpedes_train:
+        name: CUHK-PEDES
         type: img_txt_pair
         path: local_data/cuhkpedes_shards
-        prefix: cuhkpedes-train-{000000..000255}.tar
+        prefix: cuhkpedes-train-{000000..000004}.tar
         length: 34054
       cuhkpedes_val:
+        raw_path: /home/linkslinks/dataset/
+        name: CUHK-PEDES
         type: img_txt_pair
         path: local_data/cuhkpedes_shards
-        prefix: cuhkpedes-val-{000000..000023}.tar
+        prefix: cuhkpedes-val-{000000..000000}.tar
         length: 3078
     train:
       # - gcc3m

+ 1 - 1
datasets/bases.py

@@ -117,7 +117,7 @@ class TextDataset(Dataset):
         self.captions = captions
         self.text_length = text_length
         self.truncate = truncate
-        self.tokenizer = SimpleTokenizer()
+        self.tokenizer = SimpleTokenizer(bpe_path="/home/linkslinks/文档/ai/GroupViT/datasets/bpe_simple_vocab_16e6.txt.gz")
 
     def __len__(self):
         return len(self.caption_pids)

+ 20 - 22
datasets/build.py

@@ -65,30 +65,28 @@ def collate(batch):
 def build_dataloader(args, tranforms=None):
     logger = logging.getLogger("IRRA.dataset")
 
-    num_workers = args.num_workers
-    dataset = __factory[args.dataset_name](root=args.root_dir)
+    num_workers = args.data.num_workers
+    dataset = __factory[args.data.dataset.meta.cuhkpedes_val.name](root=args.data.dataset.meta.cuhkpedes_val.raw_path)
     num_classes = len(dataset.train_id_container)
 
-    # build dataloader for testing
-    if tranforms:
-        test_transforms = tranforms
-    else:
-        test_transforms = build_transforms(img_size=args.img_size,
-                                            is_train=False)
+    val_transforms = build_transforms(img_size=(args.data.img_aug.img_size * 3, args.data.img_aug.img_size),
+                                    is_train=False)
 
-    ds = dataset.test
-    test_img_set = ImageDataset(ds['image_pids'], ds['img_paths'],
-                                test_transforms)
-    test_txt_set = TextDataset(ds['caption_pids'],
+    # use test set as validate set
+    ds = dataset.val
+    val_img_set = ImageDataset(ds['image_pids'], ds['img_paths'],
+                                val_transforms)
+    val_txt_set = TextDataset(ds['caption_pids'],
                                 ds['captions'],
-                                text_length=args.text_length)
+                                text_length=args.data.text_aug.max_seq_len)
+
+    val_img_loader = DataLoader(val_img_set,
+                                batch_size=args.data.batch_size,
+                                shuffle=False,
+                                num_workers=num_workers)
+    val_txt_loader = DataLoader(val_txt_set,
+                                batch_size=args.data.batch_size,
+                                shuffle=False,
+                                num_workers=num_workers)
 
-    test_img_loader = DataLoader(test_img_set,
-                                    batch_size=args.test_batch_size,
-                                    shuffle=False,
-                                    num_workers=num_workers)
-    test_txt_loader = DataLoader(test_txt_set,
-                                    batch_size=args.test_batch_size,
-                                    shuffle=False,
-                                    num_workers=num_workers)
-    return test_img_loader, test_txt_loader, num_classes
+    return val_img_loader, val_txt_loader, num_classes

+ 5 - 1
main_group_vit.py

@@ -50,6 +50,9 @@ from timm.utils import AverageMeter, accuracy
 from utils import (auto_resume_helper, build_dataset_class_tokens, build_optimizer, build_scheduler, data2cuda,
                    get_config, get_grad_norm, get_logger, load_checkpoint, parse_losses, reduce_tensor, save_checkpoint)
 
+from tools.cfg2arg import cfg2arg
+
+
 try:
     # noinspection PyUnresolvedReferences
     from apex import amp
@@ -109,7 +112,8 @@ def train(cfg):
     print("\n\n\n")
 
     # get image-text pair datasets dataloader
-    train_loader, val_img_loader, val_txt_loader, num_classes = build_dataloader(args)
+    # train_loader, val_img_loader, val_txt_loader, num_classes = build_dataloader(cfg)
+    val_img_loader, val_txt_loader, num_classes = build_dataloader(cfg)
 
     logger = get_logger()
 

+ 62 - 0
tools/cfg2arg.py

@@ -0,0 +1,62 @@
+import argparse
+
+
+def cfg2arg(cfg):
+    # 定义argparse对象
+    parser = argparse.ArgumentParser()
+
+    # 添加参数
+    parser.add_argument('--local_rank', type=int, default=cfg['local_rank'])
+    parser.add_argument('--name', type=str, default=cfg['model_name'])
+    parser.add_argument('--output_dir', type=str, default=cfg['output'])
+    parser.add_argument('--log_period', type=int, default=cfg['print_freq'])
+    parser.add_argument('--eval_period', type=int, default=cfg['evaluate']['eval_freq'])
+    parser.add_argument('--val_dataset', type=str, default=cfg['data']['dataset']['val'][0])
+    parser.add_argument('--resume', type=bool, default=cfg['checkpoint']['auto_resume'])
+    parser.add_argument('--resume_ckpt_file', type=str, default=cfg['checkpoint']['resume'])
+    parser.add_argument('--pretrain_choice', type=str, default='ViT-B/16')  # 这里假设预训练选择是固定的
+    parser.add_argument('--temperature', type=float, default=cfg['model']['contrast_temperature'])
+    parser.add_argument('--img_aug', type=bool, default=cfg['data']['img_aug']['deit_aug'])
+    parser.add_argument('--cmt_depth', type=int, default=4)  # 这里假设cmt_depth是固定的
+    parser.add_argument('--masked_token_rate', type=float, default=0.8)  # 这里假设masked_token_rate是固定的
+    parser.add_argument('--masked_token_unchanged_rate', type=float, default=0.1)  # 这里假设masked_token_unchanged_rate是固定的
+    parser.add_argument('--lr_factor', type=float, default=5.0)  # 这里假设lr_factor是固定的
+    parser.add_argument('--MLM', type=bool, default=True)  # 这里假设MLM是固定的
+    parser.add_argument('--loss_names', type=str, default='sdm+mlm+id')  # 这里假设loss_names是固定的
+    parser.add_argument('--mlm_loss_weight', type=float, default=1.0)  # 这里假设mlm_loss_weight是固定的
+    parser.add_argument('--id_loss_weight', type=float, default=1.0)  # 这里假设id_loss_weight是固定的
+    parser.add_argument('--img_size', type=tuple, default=(cfg['data']['img_aug']['img_size'], cfg['data']['img_aug']['img_size']))
+    parser.add_argument('--stride_size', type=int, default=16)  # 这里假设stride_size是固定的
+    parser.add_argument('--text_length', type=int, default=cfg['data']['text_aug']['max_seq_len'])
+    parser.add_argument('--vocab_size', type=int, default=cfg['model']['text_encoder']['vocab_size'])
+    parser.add_argument('--optimizer', type=str, default=cfg['train']['optimizer']['name'])
+    parser.add_argument('--lr', type=float, default=cfg['train']['base_lr'])
+    parser.add_argument('--bias_lr_factor', type=float, default=2.0)  # 这里假设bias_lr_factor是固定的
+    parser.add_argument('--momentum', type=float, default=0.9)  # 这里假设momentum是固定的
+    parser.add_argument('--weight_decay', type=float, default=cfg['train']['weight_decay'])
+    parser.add_argument('--weight_decay_bias', type=float, default=0.0)  # 这里假设weight_decay_bias是固定的
+    parser.add_argument('--alpha', type=float, default=0.9)  # 这里假设alpha是固定的
+    parser.add_argument('--beta', type=float, default=0.999)  # 这里假设beta是固定的
+    parser.add_argument('--num_epoch', type=int, default=cfg['train']['epochs'])
+    parser.add_argument('--milestones', type=tuple, default=(20, 50))  # 这里假设milestones是固定的
+    parser.add_argument('--gamma', type=float, default=0.1)  # 这里假设gamma是固定的
+    parser.add_argument('--warmup_factor', type=float, default=0.1)  # 这里假设warmup_factor是固定的
+    parser.add_argument('--warmup_epochs', type=int, default=cfg['train']['warmup_epochs'])
+    parser.add_argument('--warmup_method', type=str, default='linear')  # 这里假设warmup_method是固定的
+    parser.add_argument('--lrscheduler', type=str, default=cfg['train']['lr_scheduler']['name'])
+    parser.add_argument('--target_lr', type=float, default=0)  # 这里假设target_lr是固定的
+    parser.add_argument('--power', type=float, default=0.9)  # 这里假设power是固定的
+    parser.add_argument('--dataset_name', type=str, default='CUHK-PEDES')  # 这里假设dataset_name是固定的
+    parser.add_argument('--sampler', type=str, default='random')  # 这里假设sampler是固定的
+    parser.add_argument('--num_instance', type=int, default=4)  # 这里假设num_instance是固定的
+    parser.add_argument('--root_dir', type=str, default='/home/linkslinks/dataset')  # 这里假设root_dir是固定的
+    parser.add_argument('--batch_size', type=int, default=cfg['data']['batch_size'])
+    parser.add_argument('--test_batch_size', type=int, default=512)  # 这里假设test_batch_size是固定的
+    parser.add_argument('--num_workers', type=int, default=cfg['data']['num_workers'])
+    parser.add_argument('--training', type=bool, default=True)  # 这里假设training是固定的
+    parser.add_argument('--distributed', type=bool, default=False)  # 这里假设distributed是固定的
+
+    # 解析参数
+    args = parser.parse_args()
+
+    return args