a b/opengait/utils/msg_manager.py
1
import time
2
import torch
3
4
import numpy as np
5
import torchvision.utils as vutils
6
import os.path as osp
7
from time import strftime, localtime
8
9
from torch.utils.tensorboard import SummaryWriter
10
from .common import is_list, is_tensor, ts2np, mkdir, Odict, NoOp
11
import logging
12
13
14
class MessageManager:
    """Collect training metrics and route them to the console, a log file and TensorBoard.

    The module creates one shared instance; ``get_msg_mgr`` returns it on rank 0
    and a ``NoOp`` stand-in on every other rank.
    """

    def __init__(self):
        # Metrics buffered since the last report (merged via Odict.append).
        self.info_dict = Odict()
        # Summary kinds accepted by write_to_tensorboard; each maps onto
        # SummaryWriter.add_<kind>.
        self.writer_hparams = ['image', 'scalar']
        # Wall-clock start of the current reporting window.
        self.time = time.time()
        # Handlers this instance attached to the shared 'opengait' logger, kept
        # so a repeated init_logger() call can detach them instead of piling up.
        self._log_handlers = []

    def init_manager(self, save_path, log_to_file, log_iter, iteration=0):
        """Prepare TensorBoard and logging for a training run.

        Args:
            save_path: output root; ``summary/`` (TensorBoard events) and, when
                ``log_to_file`` is true, ``logs/`` are created beneath it.
            log_to_file: also write log records to a timestamped ``.txt`` file.
            log_iter: emit a training report every ``log_iter`` calls to
                ``train_step``.
            iteration: iteration counter to start from; passed to SummaryWriter
                as ``purge_step`` so stale events from a crashed run at or after
                this step are hidden on resume.
        """
        self.iteration = iteration
        self.log_iter = log_iter
        summary_dir = osp.join(save_path, "summary/")
        mkdir(summary_dir)
        self.writer = SummaryWriter(summary_dir, purge_step=self.iteration)
        self.init_logger(save_path, log_to_file)

    def init_logger(self, save_path, log_to_file):
        """(Re)configure the shared ``'opengait'`` logger.

        Attaches a console handler and, if ``log_to_file`` is true, a file
        handler writing to ``<save_path>/logs/<timestamp>.txt``. Handlers added
        by a previous call on this instance are removed and closed first, so
        calling this more than once does not duplicate output.

        Args:
            save_path: output root used for the ``logs/`` directory.
            log_to_file: whether to add the file handler.
        """
        self.logger = logging.getLogger('opengait')
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False

        # getLogger() returns a process-wide singleton: without this, handlers
        # from an earlier call stay attached and every record is emitted twice
        # (and old log files are never closed).
        for handler in getattr(self, '_log_handlers', []):
            self.logger.removeHandler(handler)
            handler.close()
        self._log_handlers = []

        formatter = logging.Formatter(
            fmt='[%(asctime)s] [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
        if log_to_file:
            mkdir(osp.join(save_path, "logs/"))
            vlog = logging.FileHandler(
                osp.join(save_path, "logs/", strftime('%Y-%m-%d-%H-%M-%S', localtime())+'.txt'))
            vlog.setLevel(logging.INFO)
            vlog.setFormatter(formatter)
            self.logger.addHandler(vlog)
            self._log_handlers.append(vlog)

        console = logging.StreamHandler()
        console.setFormatter(formatter)
        # The logger itself is at INFO, so DEBUG records are dropped before they
        # reach this handler unless the logger's level is lowered elsewhere.
        console.setLevel(logging.DEBUG)
        self.logger.addHandler(console)
        self._log_handlers.append(console)

    def append(self, info):
        """Buffer one step of metrics into ``info_dict``.

        Every value is wrapped in a list if it is not one already, and tensor
        elements are converted with ``ts2np``. A normalized copy is merged, so
        the caller's ``info`` mapping is left untouched and does not end up
        sharing list objects with the internal buffer.

        Args:
            info: mapping of metric name to a value or a list of values.
        """
        normalized = {}
        for k, v in info.items():
            v = v if is_list(v) else [v]
            normalized[k] = [ts2np(x) if is_tensor(x) else x for x in v]
        self.info_dict.append(normalized)

    def flush(self):
        """Clear buffered metrics and flush pending TensorBoard events."""
        self.info_dict.clear()
        self.writer.flush()

    def write_to_tensorboard(self, summary):
        """Write every ``'<kind>/<tag>'`` entry of ``summary`` at the current iteration.

        Kinds not listed in ``writer_hparams`` are skipped with a warning.
        ``image`` values go through ``torchvision.utils.make_grid`` with
        per-image normalization (presumably a batch tensor — make_grid's
        expected input; confirm with callers). ``scalar`` values are reduced
        with ``.mean()`` when they support it and written as-is otherwise.

        Args:
            summary: mapping of ``'<kind>/<tag>'`` keys to values.
        """
        for k, v in summary.items():
            module_name = k.split('/')[0]
            if module_name not in self.writer_hparams:
                self.log_warning(
                    'Not Expected --Summary-- type [{}] appear!!!{}'.format(k, self.writer_hparams))
                continue
            # Strip only the leading "<kind>/" prefix; an unbounded replace()
            # would also delete the same text deeper inside the tag.
            board_name = k.replace(module_name + "/", '', 1)
            writer_module = getattr(self.writer, 'add_' + module_name)
            v = v.detach() if is_tensor(v) else v
            if 'image' in module_name:
                v = vutils.make_grid(v, normalize=True, scale_each=True)
            elif module_name == 'scalar':
                try:
                    v = v.mean()
                except (AttributeError, TypeError, RuntimeError):
                    # Plain Python numbers have no .mean(), and some tensor
                    # dtypes (e.g. integer) reject it: log the value unchanged.
                    pass
            writer_module(board_name, v, self.iteration)

    def log_training_info(self):
        """Log one progress line and restart the timer.

        The line has the iteration, the seconds elapsed since the last report,
        and the mean of every buffered key containing ``'scalar'``. In each key,
        ``'scalar/'`` is removed and remaining ``/`` become ``_``.
        """
        now = time.time()
        string = "Iteration {:0>5}, Cost {:.2f}s".format(
            self.iteration, now - self.time)
        for k, v in self.info_dict.items():
            if 'scalar' not in k:
                continue
            k = k.replace('scalar/', '').replace('/', '_')
            string += ", {0}={1:.4f}".format(k, np.mean(v))
        self.log_info(string)
        self.reset_time()

    def reset_time(self):
        """Start a new timing window for the next progress report."""
        self.time = time.time()

    def train_step(self, info, summary):
        """Advance the iteration counter, buffer ``info`` and report every ``log_iter`` steps.

        Args:
            info: per-step metrics, see ``append``.
            summary: TensorBoard entries, see ``write_to_tensorboard``; only
                written on reporting steps.
        """
        self.iteration += 1
        self.append(info)
        if self.iteration % self.log_iter == 0:
            self.log_training_info()
            # Write the summary before flushing so its events reach disk now
            # instead of lingering in the writer's buffer until the next flush.
            self.write_to_tensorboard(summary)
            self.flush()

    def log_debug(self, *args, **kwargs):
        """Forward to ``logger.debug``.

        The logger is set to INFO in init_logger, so these records are dropped
        unless that level is changed.
        """
        self.logger.debug(*args, **kwargs)

    def log_info(self, *args, **kwargs):
        """Forward to ``logger.info``."""
        self.logger.info(*args, **kwargs)

    def log_warning(self, *args, **kwargs):
        """Forward to ``logger.warning``."""
        self.logger.warning(*args, **kwargs)
111
112
113
# Per-process singletons: get_msg_mgr() returns `msg_mgr` on rank 0 and the
# `noop` stand-in (a NoOp instance, presumably accepting and ignoring any call)
# on every other rank.
noop = NoOp()
msg_mgr = MessageManager()
115
116
117
def get_msg_mgr():
    """Return the message manager for the calling process.

    Rank 0 gets the shared ``msg_mgr``. So does any process that is not part of
    an initialized ``torch.distributed`` process group. Every other rank gets
    the ``noop`` stand-in, so only one process drives logging and summaries.
    """
    dist = torch.distributed
    # get_rank() raises when no default process group exists (single-process
    # runs, or callers that run before init_process_group), so treat that case
    # as the lone rank-0 process instead of crashing.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() > 0:
        return noop
    return msg_mgr