|
a |
|
b/utils/logger.py |
|
|
1 |
import os |
|
|
2 |
from enum import Enum |
|
|
3 |
|
|
|
4 |
from tensorflow.compat.v1 import variable_scope, placeholder |
|
|
5 |
from tensorflow.compat.v1.summary import image, FileWriter, scalar |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class Phase(Enum): |
|
|
9 |
TRAIN = 'TRAIN' |
|
|
10 |
VAL = 'VAL' |
|
|
11 |
TEST = 'TEST' |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
class Logger: |
|
|
15 |
def __init__(self, sess, summary_dir): |
|
|
16 |
self.sess = sess |
|
|
17 |
self.summary_placeholders = {} |
|
|
18 |
self.summary_ops = {} |
|
|
19 |
self.train_summary_writer = FileWriter(os.path.join(summary_dir, Phase.TRAIN.value), self.sess.graph) |
|
|
20 |
self.val_summary_writer = FileWriter(os.path.join(summary_dir, Phase.VAL.value)) |
|
|
21 |
self.test_summary_writer = FileWriter(os.path.join(summary_dir, Phase.TEST.value)) |
|
|
22 |
|
|
|
23 |
# it can summarize scalars and images. |
|
|
24 |
def summarize(self, step, phase: Phase = Phase.TRAIN, scope="", summaries_dict=None): |
|
|
25 |
""" |
|
|
26 |
:param step: the step of the summary |
|
|
27 |
:param phase: use the train summary writer or the test one |
|
|
28 |
:param scope: variable scope |
|
|
29 |
:param summaries_dict: the dict of the summaries values (tag,value) |
|
|
30 |
:return: |
|
|
31 |
""" |
|
|
32 |
if phase == Phase.TRAIN: |
|
|
33 |
summary_writer = self.train_summary_writer |
|
|
34 |
elif phase == Phase.VAL: |
|
|
35 |
summary_writer = self.val_summary_writer |
|
|
36 |
elif phase == Phase.TEST: |
|
|
37 |
summary_writer = self.test_summary_writer |
|
|
38 |
else: |
|
|
39 |
raise ValueError(f'Illegal Argument for summarizer: {phase.value}') |
|
|
40 |
|
|
|
41 |
with variable_scope(scope): |
|
|
42 |
|
|
|
43 |
if summaries_dict is not None: |
|
|
44 |
summary_list = [] |
|
|
45 |
for tag, value in summaries_dict.items(): |
|
|
46 |
if value is None: |
|
|
47 |
continue |
|
|
48 |
if tag not in self.summary_ops: |
|
|
49 |
if len(value.shape) <= 1: |
|
|
50 |
self.summary_placeholders[tag] = placeholder('float32', value.shape, name=tag) |
|
|
51 |
self.summary_ops[tag] = scalar(tag, self.summary_placeholders[tag]) |
|
|
52 |
else: |
|
|
53 |
self.summary_placeholders[tag] = placeholder('float32', [None] + list(value.shape[1:]), name=tag) |
|
|
54 |
self.summary_ops[tag] = image(tag, self.summary_placeholders[tag], max_outputs=100) |
|
|
55 |
|
|
|
56 |
summary_list.append(self.sess.run(self.summary_ops[tag], {self.summary_placeholders[tag]: value})) |
|
|
57 |
|
|
|
58 |
for summary in summary_list: |
|
|
59 |
summary_writer.add_summary(summary, step) |
|
|
60 |
summary_writer.flush() |