logger.py 952 B

1234567891011121314151617181920212223242526272829303132
  1. import logging
  2. import os
  3. import sys
  4. import os.path as op
  def setup_logger(name, save_dir, if_train, distributed_rank=0):
      """Configure and return the named logger with stdout and file output.

      Args:
          name: Logger name passed to ``logging.getLogger``. Calls with the
              same name reconfigure the same logger object.
          save_dir: Directory that receives the log file. It is created
              (including parents) if missing. An empty string means the
              current working directory.
          if_train: If true, log to ``train_log.txt`` opened with mode ``'w'``
              (truncated). Otherwise log to ``test_log.txt`` opened with
              mode ``'a'`` (appended).
          distributed_rank: Process rank. For ranks > 0 the logger is
              returned with its level set but without any handlers attached
              by this function.

      Returns:
          The ``logging.Logger`` registered under ``name``.
      """
      logger = logging.getLogger(name)
      logger.setLevel(logging.DEBUG)
      # don't log results for the non-master process
      if distributed_rank > 0:
          return logger

      # getLogger() hands back the same object for the same name. Without
      # this cleanup, every call would stack another handler pair, each
      # record would be emitted once per call, and the previous file
      # handles would never be closed. Only handlers installed by this
      # function are removed; handlers added elsewhere are left alone.
      for old in list(logger.handlers):
          if getattr(old, "_setup_logger_owned", False):
              logger.removeHandler(old)
              old.close()  # StreamHandler.close() does not close sys.stdout

      formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

      ch = logging.StreamHandler(stream=sys.stdout)
      ch.setLevel(logging.DEBUG)
      ch.setFormatter(formatter)
      ch._setup_logger_owned = True
      logger.addHandler(ch)

      # An empty save_dir means "current directory": op.join("", f) == f.
      if save_dir and not op.exists(save_dir):
          print(f"{save_dir} is not exists, create given directory")
          # exist_ok avoids a crash if another process creates it in between.
          os.makedirs(save_dir, exist_ok=True)

      # if_train -> fresh train_log.txt; otherwise append to test_log.txt.
      if if_train:
          log_name, mode = "train_log.txt", "w"
      else:
          log_name, mode = "test_log.txt", "a"
      # Explicit UTF-8 so the log file does not depend on the platform locale.
      fh = logging.FileHandler(op.join(save_dir, log_name), mode=mode, encoding="utf-8")
      fh.setLevel(logging.DEBUG)
      fh.setFormatter(formatter)
      fh._setup_logger_owned = True
      logger.addHandler(fh)
      return logger