Diff of /utils/logger.py [000000] .. [978658]

Switch to unified view

a b/utils/logger.py
1
import os
2
from enum import Enum
3
4
from tensorflow.compat.v1 import variable_scope, placeholder
5
from tensorflow.compat.v1.summary import image, FileWriter, scalar
6
7
8
class Phase(Enum):
9
    TRAIN = 'TRAIN'
10
    VAL = 'VAL'
11
    TEST = 'TEST'
12
13
14
class Logger:
15
    def __init__(self, sess, summary_dir):
16
        self.sess = sess
17
        self.summary_placeholders = {}
18
        self.summary_ops = {}
19
        self.train_summary_writer = FileWriter(os.path.join(summary_dir, Phase.TRAIN.value), self.sess.graph)
20
        self.val_summary_writer = FileWriter(os.path.join(summary_dir, Phase.VAL.value))
21
        self.test_summary_writer = FileWriter(os.path.join(summary_dir, Phase.TEST.value))
22
23
    # it can summarize scalars and images.
24
    def summarize(self, step, phase: Phase = Phase.TRAIN, scope="", summaries_dict=None):
25
        """
26
        :param step: the step of the summary
27
        :param phase: use the train summary writer or the test one
28
        :param scope: variable scope
29
        :param summaries_dict: the dict of the summaries values (tag,value)
30
        :return:
31
        """
32
        if phase == Phase.TRAIN:
33
            summary_writer = self.train_summary_writer
34
        elif phase == Phase.VAL:
35
            summary_writer = self.val_summary_writer
36
        elif phase == Phase.TEST:
37
            summary_writer = self.test_summary_writer
38
        else:
39
            raise ValueError(f'Illegal Argument for summarizer: {phase.value}')
40
41
        with variable_scope(scope):
42
43
            if summaries_dict is not None:
44
                summary_list = []
45
                for tag, value in summaries_dict.items():
46
                    if value is None:
47
                        continue
48
                    if tag not in self.summary_ops:
49
                        if len(value.shape) <= 1:
50
                            self.summary_placeholders[tag] = placeholder('float32', value.shape, name=tag)
51
                            self.summary_ops[tag] = scalar(tag, self.summary_placeholders[tag])
52
                        else:
53
                            self.summary_placeholders[tag] = placeholder('float32', [None] + list(value.shape[1:]), name=tag)
54
                            self.summary_ops[tag] = image(tag, self.summary_placeholders[tag], max_outputs=100)
55
56
                    summary_list.append(self.sess.run(self.summary_ops[tag], {self.summary_placeholders[tag]: value}))
57
58
                for summary in summary_list:
59
                    summary_writer.add_summary(summary, step)
60
                summary_writer.flush()