a b/trainers/trainer_utils.py
1
import numpy as np
2
3
from utils.utils import normalize
4
5
6
def get_summary_dict(batch, run, visualization_keys=None, *others):
    """Build scalar summaries and side-by-side visualizations for one batch.

    Args:
        batch: Indexable sequence of inputs; ``batch[i]`` is the i-th sample.
        run: Mapping of output name -> value. Entries whose value is ``None``
            are dropped before anything else is done.
        visualization_keys: Keys of ``run`` whose per-sample values are shown
            next to each input. Defaults to ``['reconstruction', 'L1']``.
            NOTE: because ``*others`` follows this parameter, a third
            positional argument always binds here, never to ``others``.
        *others: Extra indexable sequences whose i-th element is appended to
            the i-th visualization row.

    Returns:
        ``(scalars, visuals)``:
        ``scalars`` holds the (non-``None``) ``run`` entries whose value is
        zero-dimensional (``np.ndim(value) == 0``), skipping float NaNs
        (``isinstance(value, float)``, which also covers ``np.float64``).
        ``visuals`` is ``np.asarray`` of one row per sample: the input, then
        each ``visualization_keys`` panel, then each ``others`` panel, each
        passed through ``normalize``, joined with ``np.hstack`` and scaled
        by 255.

    Raises:
        KeyError: if ``batch`` is non-empty and a ``visualization_keys``
            entry is missing from ``run`` (or its value was ``None``).
    """
    if visualization_keys is None:
        visualization_keys = ['reconstruction', 'L1']
    run = {key: value for key, value in run.items() if value is not None}

    visuals = np.asarray([
        255 * np.hstack([
            normalize(sample),
            *[normalize(run[key][i]) for key in visualization_keys],
            *[normalize(element[i]) for element in others],
        ])
        for i, sample in enumerate(batch)
    ])

    scalars = {
        key: value
        for key, value in run.items()
        # NaN is the only value unequal to itself; skip NaN floats.
        if not (isinstance(value, float) and value != value)
        # np.ndim also handles plain Python numbers, which have no ``.ndim``
        # attribute (the previous ``value.ndim`` raised AttributeError).
        and np.ndim(value) == 0
    }
    return scalars, visuals