|
a |
|
b/trainers/trainer_utils.py |
|
|
1 |
import numpy as np |
|
|
2 |
|
|
|
3 |
from utils.utils import normalize |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
def get_summary_dict(batch, run, visualization_keys=None, *others):
    """Split one step's outputs into scalar summaries and tiled visuals.

    Args:
        batch: Sized, indexable collection of input samples; ``batch[i]`` is
            passed to ``normalize``.
        run: Mapping from output name to value. Entries whose value is
            ``None`` are ignored entirely.
        visualization_keys: Keys of ``run`` whose ``i``-th entries are placed
            next to ``batch[i]`` in the visual. Defaults to
            ``['reconstruction', 'L1']``.
        *others: Additional indexable collections; ``element[i]`` of each is
            appended after the ``visualization_keys`` entries.

    Returns:
        A ``(scalars, visuals)`` tuple. ``scalars`` is a dict with the
        zero-dimensional entries of ``run`` (Python ``float`` NaNs removed).
        ``visuals`` is an array with one row-image per sample, each being
        ``255 * np.hstack([...])`` of the normalized input, the selected
        ``run`` outputs and the ``others`` entries, in that order.

    Raises:
        KeyError: If a key in ``visualization_keys`` is missing from ``run``
            (or its value was ``None``) while ``batch`` is non-empty.
    """
    if visualization_keys is None:
        visualization_keys = ['reconstruction', 'L1']

    # Drop outputs that were not produced for this step.
    run = {name: value for name, value in run.items() if value is not None}

    # NOTE(review): scaling by 255 presumes ``normalize`` maps into [0, 1];
    # that helper is defined elsewhere - confirm.
    visuals = np.asarray([
        255 * np.hstack([
            normalize(batch[i]),
            *[normalize(run[key][i]) for key in visualization_keys],
            *[normalize(element[i]) for element in others],
        ])
        for i in range(len(batch))
    ])

    scalars = {}
    for name, value in run.items():
        # ``value != value`` is true only for NaN. The exact-type check is
        # kept as in the original filter, so numpy NaN scalars are not
        # removed here.
        if type(value) is float and value != value:
            continue
        # np.ndim() returns ``value.ndim`` when present and also works for
        # plain Python numbers, which have no ``.ndim`` attribute.
        if np.ndim(value) == 0:
            scalars[name] = value
    return scalars, visuals