# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os
from collections import OrderedDict

import torch


class Checkpointer:
    """Saves and loads model, optimizer, and scheduler state to and from disk."""

    def __init__(
        self,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        if logger is None:
            logger = logging.getLogger(__name__)
        self.logger = logger

    def save(self, name, **kwargs):
        if not self.save_dir:
            return
        if not self.save_to_disk:
            return
        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)
        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)

    def load(self, f=None):
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found.")
            return {}
        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint)
        # return any further checkpoint data
        return checkpoint

    def resume(self, f=None):
        if not f:
            # a checkpoint file is required when resuming
            self.logger.info("No checkpoint found.")
            raise IOError("No checkpoint file found at {}".format(f))
        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint)
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
        # return any further checkpoint data
        return checkpoint

    def _load_file(self, f):
        return torch.load(f, map_location=torch.device("cpu"))

    def _load_model(self, checkpoint, except_keys=None):
        load_state_dict(self.model, checkpoint.pop("model"), except_keys)


def check_key(key, except_keys):
    """Return True if `key` contains any of the substrings in `except_keys`."""
    if except_keys is None:
        return False
    return any(except_key in key for except_key in except_keys)
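
# Illustrative example (the key names below are hypothetical, not from this
# module): with except_keys=["cls_score"],
# check_key("roi_heads.predictor.cls_score.weight", ["cls_score"]) returns
# True, so that parameter is skipped during checkpoint loading.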


def align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys=None):
    """Copy entries from loaded_state_dict into model_state_dict by suffix match.

    A loaded key is matched to a current key when the current key ends with
    the loaded key; the longest such match wins. Keys containing any substring
    in `except_keys` are left untouched.
    """
    current_keys = sorted(list(model_state_dict.keys()))
    loaded_keys = sorted(list(loaded_state_dict.keys()))
    # get a matrix of string matches, where each (i, j) entry corresponds to
    # the length of the loaded_key string, if it matches
    match_matrix = [
        len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys
    ]
    match_matrix = torch.as_tensor(match_matrix).view(
        len(current_keys), len(loaded_keys)
    )
    max_match_size, idxs = match_matrix.max(1)
    # remove indices that correspond to no-match
    idxs[max_match_size == 0] = -1
    # used for logging
    max_size = max([len(key) for key in current_keys]) if current_keys else 1
    max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
    log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
    logger = logging.getLogger("PersonSearch.checkpoint")
    for idx_new, idx_old in enumerate(idxs.tolist()):
        if idx_old == -1:
            continue
        key = current_keys[idx_new]
        key_old = loaded_keys[idx_old]
        if check_key(key, except_keys):
            continue
        model_state_dict[key] = loaded_state_dict[key_old]
        logger.info(
            log_str_template.format(
                key,
                max_size,
                key_old,
                max_size_loaded,
                tuple(loaded_state_dict[key_old].shape),
            )
        )
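
# Illustrative example of the suffix matching above (hypothetical key names):
# a loaded key "conv1.weight" matches a current key
# "backbone.body.conv1.weight" because the latter ends with the former, so
# pretrained weights load into a wrapped model without manual renaming.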


def strip_prefix_if_present(state_dict, prefix):
    keys = sorted(state_dict.keys())
    if not all(key.startswith(prefix) for key in keys):
        return state_dict
    stripped_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # slice off the prefix rather than using str.replace, which would also
        # drop any later occurrence of the prefix inside the key
        stripped_state_dict[key[len(prefix):]] = value
    return stripped_state_dict
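
# For example (hypothetical keys), {"module.conv.weight": w} becomes
# {"conv.weight": w}, while a dict whose keys do not all start with the
# prefix is returned unchanged.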


def load_state_dict(model, loaded_state_dict, except_keys=None):
    model_state_dict = model.state_dict()
    # if the state_dict comes from a model that was wrapped in a
    # DataParallel or DistributedDataParallel during serialization,
    # remove the "module." prefix before performing the matching
    loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")
    align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys)
    # use strict loading
    model.load_state_dict(model_state_dict)
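

# Minimal usage sketch (an assumption for illustration, not part of the
# original module): round-trips a tiny model through save() and resume().
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    checkpointer = Checkpointer(model, optimizer, save_dir=".", save_to_disk=True)
    # extra keyword arguments are stored alongside the state dicts
    checkpointer.save("demo_checkpoint", iteration=100)
    # resume() restores model and optimizer state and returns the leftovers
    extra = checkpointer.resume("demo_checkpoint.pth")
    print(extra)  # {'iteration': 100}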