--- a +++ b/opengait/utils/msg_manager.py @@ -0,0 +1,121 @@ +import time +import torch + +import numpy as np +import torchvision.utils as vutils +import os.path as osp +from time import strftime, localtime + +from torch.utils.tensorboard import SummaryWriter +from .common import is_list, is_tensor, ts2np, mkdir, Odict, NoOp +import logging + + +class MessageManager: + def __init__(self): + self.info_dict = Odict() + self.writer_hparams = ['image', 'scalar'] + self.time = time.time() + + def init_manager(self, save_path, log_to_file, log_iter, iteration=0): + self.iteration = iteration + self.log_iter = log_iter + mkdir(osp.join(save_path, "summary/")) + self.writer = SummaryWriter( + osp.join(save_path, "summary/"), purge_step=self.iteration) + self.init_logger(save_path, log_to_file) + + def init_logger(self, save_path, log_to_file): + # init logger + self.logger = logging.getLogger('opengait') + self.logger.setLevel(logging.INFO) + self.logger.propagate = False + formatter = logging.Formatter( + fmt='[%(asctime)s] [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + if log_to_file: + mkdir(osp.join(save_path, "logs/")) + vlog = logging.FileHandler( + osp.join(save_path, "logs/", strftime('%Y-%m-%d-%H-%M-%S', localtime())+'.txt')) + vlog.setLevel(logging.INFO) + vlog.setFormatter(formatter) + self.logger.addHandler(vlog) + + console = logging.StreamHandler() + console.setFormatter(formatter) + console.setLevel(logging.DEBUG) + self.logger.addHandler(console) + + def append(self, info): + for k, v in info.items(): + v = [v] if not is_list(v) else v + v = [ts2np(_) if is_tensor(_) else _ for _ in v] + info[k] = v + self.info_dict.append(info) + + def flush(self): + self.info_dict.clear() + self.writer.flush() + + def write_to_tensorboard(self, summary): + + for k, v in summary.items(): + module_name = k.split('/')[0] + if module_name not in self.writer_hparams: + self.log_warning( + 'Not Expected --Summary-- type [{}] appear!!!{}'.format(k, self.writer_hparams)) + continue + board_name = k.replace(module_name + "/", '') + writer_module = getattr(self.writer, 'add_' + module_name) + v = v.detach() if is_tensor(v) else v + v = vutils.make_grid( + v, normalize=True, scale_each=True) if 'image' in module_name else v + if module_name == 'scalar': + try: + v = v.mean() + except: + v = v + writer_module(board_name, v, self.iteration) + + def log_training_info(self): + now = time.time() + string = "Iteration {:0>5}, Cost {:.2f}s".format( + self.iteration, now-self.time, end="") + for i, (k, v) in enumerate(self.info_dict.items()): + if 'scalar' not in k: + continue + k = k.replace('scalar/', '').replace('/', '_') + end = "\n" if i == len(self.info_dict)-1 else "" + string += ", {0}={1:.4f}".format(k, np.mean(v), end=end) + self.log_info(string) + self.reset_time() + + def reset_time(self): + self.time = time.time() + + def train_step(self, info, summary): + self.iteration += 1 + self.append(info) + if self.iteration % self.log_iter == 0: + self.log_training_info() + self.flush() + self.write_to_tensorboard(summary) + + def log_debug(self, *args, **kwargs): + self.logger.debug(*args, **kwargs) + + def log_info(self, *args, **kwargs): + self.logger.info(*args, **kwargs) + + def log_warning(self, *args, **kwargs): + self.logger.warning(*args, **kwargs) + + +msg_mgr = MessageManager() +noop = NoOp() + + +def get_msg_mgr(): + if torch.distributed.get_rank() > 0: + return noop + else: + return msg_mgr