@@ -45,10 +45,15 @@ from mmseg.apis import multi_gpu_test
 from models import build_model
 from omegaconf import OmegaConf, read_write
 from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
+from datasets.build import build_dataloader
 from timm.utils import AverageMeter, accuracy
 from utils import (auto_resume_helper, build_dataset_class_tokens, build_optimizer, build_scheduler, data2cuda,
                    get_config, get_grad_norm, get_logger, load_checkpoint, parse_losses, reduce_tensor, save_checkpoint)
+from tools.cfg2arg import cfg2arg
+from utils.metrics import Evaluator
+
+

 try:
     # noinspection PyUnresolvedReferences
     from apex import amp
@@ -103,6 +108,14 @@ def train(cfg):
         data_loader_train, data_loader_val = build_loader(cfg.data)
     data_loader_seg = build_seg_dataloader(build_seg_dataset(cfg.evaluate.seg))

+    print("\n\n\n")
+    print(cfg)
+    print("\n\n\n")
+
+    # get image-text pair datasets dataloader
+    # train_loader, val_img_loader, val_txt_loader, num_classes = build_dataloader(cfg)
+    val_img_loader, val_txt_loader, num_classes = build_dataloader(cfg)
+
     logger = get_logger()

     logger.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
@@ -148,8 +161,8 @@ def train(cfg):
         else:
             logger.info(f'no checkpoint found in {cfg.output}, ignoring auto resume')

-    max_accuracy = max_miou = 0.0
-    max_metrics = {'max_accuracy': max_accuracy, 'max_miou': max_miou}
+    max_accuracy = max_miou = max_rank1 = 0.0
+    max_metrics = {'max_accuracy': max_accuracy, 'max_miou': max_miou, 'max_rank1': max_rank1}
     if cfg.checkpoint.resume:
         max_metrics = load_checkpoint(cfg, model_without_ddp, optimizer, lr_scheduler)
@@ -160,6 +173,11 @@ def train(cfg):
     if 'seg' in cfg.evaluate.task:
         miou = validate_seg(cfg, data_loader_seg, model)
         logger.info(f'mIoU of the network on the {len(data_loader_seg.dataset)} test images: {miou:.2f}%')
+    if 'reid' in cfg.evaluate.task:
+        # mrank1 = validate_reid(cfg, data_loader_reid, model)
+        mrank1 = validate_reid(cfg, val_img_loader, val_txt_loader, model)
+        # logger.info(f'Rank1 of the network on the {len(data_loader_reid)} test images: {mrank1:.2f}%')
+        logger.info(f'Rank1 of the network on the {len(val_img_loader)} test images: {mrank1:.2f}%')
     if cfg.evaluate.eval_only:
         return

@@ -170,7 +188,8 @@ def train(cfg):
         if dist.get_rank() == 0 and (epoch % cfg.checkpoint.save_freq == 0 or epoch == (cfg.train.epochs - 1)):
             save_checkpoint(cfg, epoch, model_without_ddp, {
                 'max_accuracy': max_accuracy,
-                'max_miou': max_miou
+                'max_miou': max_miou,
+                'max_rank1': max_rank1
             }, optimizer, lr_scheduler)
         dist.barrier()
         loss_train = loss_train_dict['total_loss']
@@ -198,6 +217,16 @@ def train(cfg):
             dist.barrier()
             max_miou = max_metrics['max_miou']
             logger.info(f'Max mIoU: {max_miou:.2f}%')
+        if 'reid' in cfg.evaluate.task:
+            mrank1 = validate_reid(cfg, val_img_loader, val_txt_loader, model)
+            logger.info(f'mRank1 of the network on the {len(val_img_loader)} test images: {mrank1:.2f}%')
+            max_metrics['max_rank1'] = max(max_metrics['max_rank1'], mrank1)
+            if cfg.evaluate.reid.save_best and dist.get_rank() == 0 and mrank1 > max_rank1:
+                save_checkpoint(
+                    cfg, epoch, model_without_ddp, max_metrics, optimizer, lr_scheduler, suffix='best_rank1')
+            dist.barrier()
+            max_rank1 = max_metrics['max_rank1']
+            logger.info(f'Max mRank1: {max_rank1:.2f}%')

         if wandb is not None:
             log_stat = {f'epoch/train_{k}': v for k, v in loss_train_dict.items()}
@@ -206,6 +235,7 @@ def train(cfg):
                 'epoch/val_acc5': acc5,
                 'epoch/val_loss': loss,
                 'epoch/val_miou': miou,
+                'epoch/val_rank1': mrank1,
                 'epoch/epoch': epoch,
                 'epoch/n_parameters': n_parameters
             })
@@ -413,6 +443,44 @@ def validate_seg(config, data_loader, model):
     return miou_result


+@torch.no_grad()
+def validate_reid(cfg, img_loader, txt_loader, model):
+    logger = get_logger()
+    dist.barrier()
+    # model.eval()
+    evaluator = Evaluator(img_loader, txt_loader)
+
+    if hasattr(model, 'module'):
+        model_without_ddp = model.module
+    else:
+        model_without_ddp = model
+
+    # reid_model = build_reid_inference(model_without_ddp, img_loader, txt_loader, cfg.evaluate.reid)
+
+    # mmddp_model = MMDistributedDataParallel(
+    #     reid_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
+    rank1 = evaluator.eval(model_without_ddp.eval())
+    # results = multi_gpu_test(
+    #     model=mmddp_model,
+    #     data_loader=img_loader,
+    #     tmpdir=None,
+    #     gpu_collect=True,
+    #     efficient_test=False,
+    #     pre_eval=True,
+    #     format_only=False)
+
+    # if dist.get_rank() == 0:
+    #     metric = [img_loader.dataset.evaluate(results, metric='Rank-1')]
+    # else:
+    #     metric = [None]
+    # dist.broadcast_object_list(metric)
+    # rank1_result = metric[0]['Rank-1'] * 100
+
+    torch.cuda.empty_cache()
+
+    return rank1
+
+
 def main():
     args = parse_args()
     cfg = get_config(args)
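
The ReID evaluation added above is driven entirely by the evaluation config: the branch runs only when 'reid' appears in cfg.evaluate.task, and cfg.evaluate.reid.save_best decides whether a best_rank1 checkpoint is written. The snippet below is a hypothetical illustration of just those fields, built with OmegaConf since the script already uses it; the real keys live in the project's YAML configs and may carry additional options.

from omegaconf import OmegaConf

# Hypothetical evaluate section; only the fields this patch reads are shown.
evaluate_cfg = OmegaConf.create({
    'eval_only': False,        # True makes train() return right after the initial evaluation
    'task': ['seg', 'reid'],   # adding 'reid' enables validate_reid
    'reid': {
        'save_best': True,     # write a checkpoint with suffix 'best_rank1' when Rank-1 improves
    },
})

assert 'reid' in evaluate_cfg.task and evaluate_cfg.reid.save_best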
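build_dataloader comes from datasets.build, which is not part of upstream GroupViT, so its exact return types are not visible in this diff. From the two call sites, the evaluation path unpacks three values, and the commented-out call hints at a four-value training variant. The stub below only documents that assumed signature for orientation; it is not the real implementation.

from typing import Tuple

from torch.utils.data import DataLoader


def build_dataloader(cfg) -> Tuple[DataLoader, DataLoader, int]:
    """Assumed contract of datasets.build.build_dataloader (sketch only).

    Returns:
        val_img_loader: loader over the ReID validation images, presumably yielding (person_id, image) batches.
        val_txt_loader: loader over the query captions, presumably yielding (person_id, tokenized caption) batches.
        num_classes: presumably the number of distinct person identities.
    """
    raise NotImplementedError('see datasets/build.py in the actual repository')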
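validate_reid delegates all metric computation to utils.metrics.Evaluator: it is constructed from the image and text loaders and queried with evaluator.eval(model), which is expected to hand back Rank-1 as a percentage. For readers without that module, here is a minimal sketch of such a text-to-image retrieval evaluator. The encode_image/encode_text method names and the (person_id, data) batch layout are assumptions made for illustration, and the project's real Evaluator likely reports further metrics such as Rank-5/10 and mAP.

import torch
import torch.nn.functional as F


class Evaluator:
    """Minimal text-to-image Rank-1 evaluator (sketch only, not the project's utils.metrics.Evaluator)."""

    def __init__(self, img_loader, txt_loader):
        self.img_loader = img_loader  # assumed to yield (person_id, image) batches
        self.txt_loader = txt_loader  # assumed to yield (person_id, tokenized caption) batches

    @torch.no_grad()
    def eval(self, model):
        device = next(model.parameters()).device
        img_feats, img_pids, txt_feats, txt_pids = [], [], [], []
        for pid, img in self.img_loader:
            img_feats.append(F.normalize(model.encode_image(img.to(device)), dim=-1))
            img_pids.append(pid.to(device))
        for pid, caption in self.txt_loader:
            txt_feats.append(F.normalize(model.encode_text(caption.to(device)), dim=-1))
            txt_pids.append(pid.to(device))
        img_feats, txt_feats = torch.cat(img_feats), torch.cat(txt_feats)
        img_pids, txt_pids = torch.cat(img_pids), torch.cat(txt_pids)

        sim = txt_feats @ img_feats.t()  # one caption query per row, gallery images as columns
        top1 = sim.argmax(dim=1)         # index of the highest-scoring gallery image for each caption
        rank1 = (img_pids[top1] == txt_pids).float().mean().item() * 100
        return rank1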