@@ -173,8 +173,8 @@ def train(cfg):
         miou = validate_seg(cfg, data_loader_seg, model)
         logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
     if 'reid' in cfg.evaluate.task:
-        mrank1 = validate_reid(cfg, data_loader_reid, model)
-        logger.info(f'Rank1 of the network on the {len(data_loader_reid)} test images: {mrank1:.2f}%')
+        mrank1 = validate_reid(cfg, val_img_loader, val_txt_loader, model)
+        logger.info(f'Rank1 of the network on the {len(val_img_loader)} test images: {mrank1:.2f}%')
     if cfg.evaluate.eval_only:
         return

@@ -215,8 +215,8 @@ def train(cfg):
             max_miou = max_metrics['max_miou']
             logger.info(f'Max mIoU: {max_miou:.2f}%')
         if 'reid' in cfg.evaluate.task:
-            mrank1 = validate_reid(cfg, data_loader_reid, model)
-            logger.info(f'mRank1 of the network on the {len(data_loader_reid)} test images: {mrank1:.2f}%')
+            mrank1 = validate_reid(cfg, val_img_loader, val_txt_loader, model)
+            logger.info(f'Rank1 of the network on the {len(val_img_loader)} test images: {mrank1:.2f}%')
             max_metrics['max_rank1'] = max(max_metrics['max_rank1'], mrank1)
             if cfg.evaluate.reid.save_best and dist.get_rank() == 0 and mrank1 > max_rank1:
                 save_checkpoint(
@@ -441,8 +441,45 @@ def validate_seg(config, data_loader, model):


 @torch.no_grad()
-def validate_reid(config, data_loader, model):
-    print()
+def validate_reid(cfg, img_loader, txt_loader, model):
+    logger = get_logger()
+    dist.barrier()
+    model.eval()
+
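+    # Unwrap DistributedDataParallel (if present) to get the underlying model.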
+    if hasattr(model, 'module'):
+        model_without_ddp = model.module
+    else:
+        model_without_ddp = model
+
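+    # Build the ReID inference wrapper from the raw model, the two eval loaders and the eval config.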
+    reid_model = build_reid_inference(model_without_ddp, img_loader, txt_loader, cfg.evaluate.reid)
+
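+    # Re-wrap in MMDistributedDataParallel so mmcv's multi_gpu_test can drive it.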
+    mmddp_model = MMDistributedDataParallel(
+        reid_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
+    mmddp_model.eval()
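+    # Run inference over the image loader on every GPU; gpu_collect gathers the results across ranks.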
+    results = multi_gpu_test(
+        model=mmddp_model,
+        data_loader=img_loader,
+        tmpdir=None,
+        gpu_collect=True,
+        efficient_test=False,
+        pre_eval=True,
+        format_only=False)
+
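+    # Rank 0 computes Rank-1 from the collected results and broadcasts it so all ranks return the same value.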
+    if dist.get_rank() == 0:
+        metric = [img_loader.dataset.evaluate(results, metric='Rank-1')]
+    else:
+        metric = [None]
+    dist.broadcast_object_list(metric)
+    rank1_result = metric[0]['Rank-1'] * 100
+
+    torch.cuda.empty_cache()
+
+    return rank1_result


 def main():