formatting.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. #
  4. # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
  5. # property and proprietary rights in and to this software, related
  6. # documentation and any modifications thereto. Any use, reproduction,
  7. # disclosure or distribution of this software and related documentation
  8. # without an express license agreement from NVIDIA CORPORATION is strictly
  9. # prohibited.
  10. #
  11. # Written by Jiarui Xu
  12. # -------------------------------------------------------------------------
  13. import torch
  14. from mmcv.parallel import DataContainer as DC
  15. class ToDataContainer(object):
  16. """Convert results to :obj:`mmcv.DataContainer`"""
  17. def __call__(self, sample):
  18. """Call function to convert data in results to
  19. :obj:`mmcv.DataContainer`.
  20. Args:
  21. sample (torch.Tensor): Input sample.
  22. Returns:
  23. DataContainer
  24. """
  25. if isinstance(sample, int):
  26. sample = torch.tensor(sample)
  27. return DC(sample, stack=True, pad_dims=None)
  28. def __repr__(self):
  29. return self.__class__.__name__