# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
# Modified by Jilan Xu
# -------------------------------------------------------------------------
import collections.abc
from collections import OrderedDict

import numpy as np
import scipy.spatial.distance
import torch
import torch.distributed as dist

from datasets import template_meta


def reduce_tensor(tensor):
    """All-reduce a tensor across distributed ranks and return the mean."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
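
# Illustrative usage (a sketch, assuming torch.distributed has been initialized
# via init_process_group and every rank holds a same-shaped tensor):
#   reduced_loss = reduce_tensor(loss.detach())  # mean loss over all ranks
#   if dist.get_rank() == 0:
#       print(f'avg loss: {reduced_loss.item():.4f}')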


def momentum_update(model_on, model_off, coeff):
    """EMA update: blend the online model's parameters into the offline (momentum) model."""
    for params_on, params_off in zip(model_on.parameters(), model_off.parameters()):
        params_off.data = coeff * params_off.data + (1 - coeff) * params_on.data
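
# Illustrative usage (a sketch, assuming a student/teacher EMA setup where
# `student` and `teacher` are architecturally identical modules):
#   momentum_update(model_on=student, model_off=teacher, coeff=0.99)
# With coeff close to 1 the offline model changes slowly, smoothing the
# online model's updates over many steps.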


def get_grad_norm(parameters, norm_type=2):
    """Return the total `norm_type`-norm over the gradients of `parameters`."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm
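
# Illustrative usage (sketch): log the gradient norm between backward() and
# the optimizer step to monitor training stability:
#   loss.backward()
#   grad_norm = get_grad_norm(model.parameters())
# Unlike torch.nn.utils.clip_grad_norm_, this only measures the norm; it does
# not rescale the gradients.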


def get_batch_size(data):
    """Recursively infer the batch size from a tensor, a mapping, or a sequence of them."""
    if isinstance(data, torch.Tensor):
        return data.size(0)
    elif isinstance(data, collections.abc.Mapping):
        return get_batch_size(data[next(iter(data))])
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
        # infer the batch size from the first element
        return get_batch_size(next(iter(data)))
    raise TypeError(f'Cannot infer batch size from {type(data)}')


def data2cuda(data):
    """Recursively move tensors (possibly nested in mappings/sequences) to the GPU."""
    if isinstance(data, torch.Tensor):
        return data.cuda(non_blocking=True)
    elif isinstance(data, collections.abc.Mapping):
        return {key: data2cuda(data[key]) for key in data}
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
        return [data2cuda(d) for d in data]
    else:
        raise TypeError(f'Cannot move {type(data)} to CUDA')


def parse_losses(losses):
    """Mean-reduce each loss entry and sum every entry whose key contains 'loss'."""
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(f'{loss_name} is not a tensor or list of tensors')
    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
    return loss, log_vars
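
# Illustrative usage (sketch): only entries whose key contains 'loss' are
# summed into the total, so other metrics can ride along for logging:
#   total_loss, log_vars = parse_losses({
#       'contrastive_loss': contrastive_loss,  # tensor, joins the total
#       'aux_loss': [loss_a, loss_b],          # list entries are mean-reduced, then summed
#       'acc1': batch_accuracy,                # logged only, excluded from the total
#   })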


def build_dataset_class_tokens(text_transform, template_set, classnames):
    """Tokenize every (template, classname) prompt; returns a [N, T, L] tensor."""
    tokens = []
    templates = template_meta[template_set]
    for classname in classnames:
        # fill each template with the class name, then tokenize
        tokens.append(torch.stack([text_transform(template.format(classname)) for template in templates]))
    # [N, T, L] - N: number of classes, T: number of templates, L: token sequence length
    tokens = torch.stack(tokens)
    return tokens


def build_dataset_class_lists(template_set, classnames):
    """Return a flat list of N * T prompt strings (every template filled with every class name)."""
    tokens = []
    templates = template_meta[template_set]
    for classname in classnames:
        for template in templates:
            tokens.append(template.format(classname))
    return tokens
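
# Illustrative usage (sketch; the available template-set names depend on
# datasets.template_meta, so 'subset' here is a hypothetical key):
#   prompts = build_dataset_class_lists('subset', ['cat', 'dog'])
#   # -> one prompt string per (class, template) pair, e.g. 'a photo of a cat.'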


def cdist_(x, metric='euclidean'):
    """Mean pairwise distance within each element of a [B, N, D] batch, averaged over the batch."""
    assert x.dim() == 3
    if metric == 'JS':
        # Jensen-Shannon distance operates on probability distributions, so
        # normalize first; scipy's name for this metric is 'jensenshannon'.
        x = torch.nn.Softmax(dim=1)(x)
        metric = 'jensenshannon'
    x = x.detach().cpu().numpy()
    return np.mean([scipy.spatial.distance.cdist(a, a, metric).mean() for a in x])
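
# Illustrative usage (sketch): a higher mean pairwise distance indicates more
# diverse embeddings within each batch element:
#   x = torch.randn(8, 64, 256)  # [B, N, D]
#   diversity = cdist_(x, metric='cosine')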