checkpoint.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os
from collections import OrderedDict

import torch


class Checkpointer:
    """Saves and restores model, optimizer, and scheduler state on disk."""

    def __init__(
        self,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        if logger is None:
            logger = logging.getLogger(__name__)
        self.logger = logger

    def save(self, name, **kwargs):
        if not self.save_dir:
            return
        if not self.save_to_disk:
            return
        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        # store any extra user-provided entries (e.g. the current iteration)
        data.update(kwargs)
        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)

    def load(self, f=None):
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found.")
            return {}
        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint)
        # return any further checkpoint data, mirroring the empty-dict branch above
        return checkpoint

    def resume(self, f=None):
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found.")
            raise IOError("No checkpoint file found at {}".format(f))
        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint)
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
        # return any further checkpoint data
        return checkpoint

    def _load_file(self, f):
        # load on CPU to avoid device-placement mismatches between save and load
        return torch.load(f, map_location=torch.device("cpu"))

    def _load_model(self, checkpoint, except_keys=None):
        load_state_dict(self.model, checkpoint.pop("model"), except_keys)
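
# Usage sketch (assumed, not part of the original file): a training loop would
# typically save periodically and resume from the latest checkpoint, e.g.
#
#   checkpointer = Checkpointer(model, optimizer, scheduler,
#                               save_dir="checkpoints", save_to_disk=True)
#   extra = checkpointer.resume("checkpoints/model_final.pth")
#   start_iter = extra.get("iteration", 0)
#   ...
#   checkpointer.save("model_{:07d}".format(iteration), iteration=iteration)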

def check_key(key, except_keys):
    # return True if `key` contains any of the excluded substrings
    if except_keys is None:
        return False
    for except_key in except_keys:
        if except_key in key:
            return True
    return False
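
# Illustrative example (assumed): with except_keys=["predictor"], a key such as
# "roi_heads.predictor.cls_score.weight" contains "predictor" and is therefore
# skipped by the matching loop below.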

def align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys=None):
    current_keys = sorted(list(model_state_dict.keys()))
    loaded_keys = sorted(list(loaded_state_dict.keys()))
    # get a matrix of string matches, where each (i, j) entry corresponds to the
    # size of the loaded_key string, if it matches
    match_matrix = [
        len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys
    ]
    match_matrix = torch.as_tensor(match_matrix).view(
        len(current_keys), len(loaded_keys)
    )
    max_match_size, idxs = match_matrix.max(1)
    # remove indices that correspond to no-match
    idxs[max_match_size == 0] = -1
    # used for logging
    max_size = max([len(key) for key in current_keys]) if current_keys else 1
    max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
    log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
    logger = logging.getLogger("PersonSearch.checkpoint")
    for idx_new, idx_old in enumerate(idxs.tolist()):
        if idx_old == -1:
            continue
        key = current_keys[idx_new]
        key_old = loaded_keys[idx_old]
        if check_key(key, except_keys):
            continue
        model_state_dict[key] = loaded_state_dict[key_old]
        logger.info(
            log_str_template.format(
                key,
                max_size,
                key_old,
                max_size_loaded,
                tuple(loaded_state_dict[key_old].shape),
            )
        )
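
# Illustrative example (assumed): the suffix matching above lets a current key
# like "backbone.body.conv1.weight" pick up a loaded key "conv1.weight"; among
# all loaded keys the current key ends with, the longest match wins, and current
# keys with no matching suffix keep their freshly initialized values.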

def strip_prefix_if_present(state_dict, prefix):
    keys = sorted(state_dict.keys())
    if not all(key.startswith(prefix) for key in keys):
        return state_dict
    stripped_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # slice rather than str.replace, so only the leading prefix is removed
        stripped_state_dict[key[len(prefix):]] = value
    return stripped_state_dict
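
# Illustrative example (assumed): {"module.conv.weight": w} becomes
# {"conv.weight": w}; a state_dict where not every key carries the prefix is
# returned unchanged.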

def load_state_dict(model, loaded_state_dict, except_keys=None):
    model_state_dict = model.state_dict()
    # if the state_dict comes from a model that was wrapped in a
    # DataParallel or DistributedDataParallel during serialization,
    # remove the "module" prefix before performing the matching
    loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")
    align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys)
    # use strict loading
    model.load_state_dict(model_state_dict)
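
# Minimal end-to-end sketch (assumed usage, not part of the original file): it
# round-trips a tiny Linear model through save() and resume() in a temporary
# directory and prints the extra checkpoint data that resume() returns.
if __name__ == "__main__":
    import tempfile

    logging.basicConfig(level=logging.INFO)
    with tempfile.TemporaryDirectory() as tmp:
        model = torch.nn.Linear(4, 2)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        Checkpointer(model, optim, save_dir=tmp, save_to_disk=True).save(
            "demo", iteration=42
        )

        fresh_model = torch.nn.Linear(4, 2)
        fresh_optim = torch.optim.SGD(fresh_model.parameters(), lr=0.1)
        extra = Checkpointer(fresh_model, fresh_optim).resume(
            os.path.join(tmp, "demo.pth")
        )
        print(extra)  # expected: {'iteration': 42}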