preprocessing.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import random
  2. import math
  3. class RandomErasing(object):
  4. """ Randomly selects a rectangle region in an image and erases its pixels.
  5. 'Random Erasing Data Augmentation' by Zhong et al.
  6. See https://arxiv.org/pdf/1708.04896.pdf
  7. Args:
  8. probability: The probability that the Random Erasing operation will be performed.
  9. sl: Minimum proportion of erased area against input image.
  10. sh: Maximum proportion of erased area against input image.
  11. r1: Minimum aspect ratio of erased area.
  12. mean: Erasing value.
  13. """
  14. def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
  15. self.probability = probability
  16. self.mean = mean
  17. self.sl = sl
  18. self.sh = sh
  19. self.r1 = r1
  20. def __call__(self, img):
  21. if random.uniform(0, 1) >= self.probability:
  22. return img
  23. for attempt in range(100):
  24. area = img.size()[1] * img.size()[2]
  25. target_area = random.uniform(self.sl, self.sh) * area
  26. aspect_ratio = random.uniform(self.r1, 1 / self.r1)
  27. h = int(round(math.sqrt(target_area * aspect_ratio)))
  28. w = int(round(math.sqrt(target_area / aspect_ratio)))
  29. if w < img.size()[2] and h < img.size()[1]:
  30. x1 = random.randint(0, img.size()[1] - h)
  31. y1 = random.randint(0, img.size()[2] - w)
  32. if img.size()[0] == 3:
  33. img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
  34. img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
  35. img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
  36. else:
  37. img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
  38. return img
  39. return img