formatting.py 1009 B

123456789101112131415161718192021222324252627282930313233
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
  3. #
  4. # This work is made available under the Nvidia Source Code License.
  5. # To view a copy of this license, visit
  6. # https://github.com/NVlabs/GroupViT/blob/main/LICENSE
  7. #
  8. # Written by Jiarui Xu
  9. # -------------------------------------------------------------------------
  10. import torch
  11. from mmcv.parallel import DataContainer as DC
  12. class ToDataContainer(object):
  13. """Convert results to :obj:`mmcv.DataContainer`"""
  14. def __call__(self, sample):
  15. """Call function to convert data in results to
  16. :obj:`mmcv.DataContainer`.
  17. Args:
  18. sample (torch.Tensor): Input sample.
  19. Returns:
  20. DataContainer
  21. """
  22. if isinstance(sample, int):
  23. sample = torch.tensor(sample)
  24. return DC(sample, stack=True, pad_dims=None)
  25. def __repr__(self):
  26. return self.__class__.__name__