Diff of /common/utils.py [000000] .. [f804b3]

Switch to unified view

a b/common/utils.py
1
import numpy as np
2
3
4
5
def log_images (x,  y_pred,  y_true=None, channel=1):
6
    images = []
7
    x_np = x[:, channel].cpu().numpy()
8
    y_true_np = y_true[:, 0].cpu().numpy()
9
    y_pred_np = y_pred[:, 0].cpu().numpy()
10
    for i in range(x_np.shape[0]):
11
        image = gray2rgb(np.squeeze(x_np[i]))
12
        image = outline(image, y_pred_np[i], color=[255, 0, 0])
13
        image = outline(image, y_true_np[i], color=[0, 255, 0])
14
        images.append(image)
15
    return images
16
17
18
19
def gray2rgb(image):
20
    w, h = image.shape
21
    image += np.abs(np.min(image))
22
    image_max = np.abs(np.max(image))
23
    if image_max > 0:
24
        image /= image_max
25
    ret = np.empty((w, h, 3), dtype=np.uint8)
26
    ret[:, :, 2] = ret[:, :, 1] = ret[:, :, 0] = image * 255
27
    return ret
28
29
30
def outline(image, mask, color):
31
    mask = np.round(mask)
32
    yy, xx = np.nonzero(mask)
33
    for y, x in zip(yy, xx):
34
        if 0.0 < np.mean(mask[max(0, y - 1) : y + 2, max(0, x - 1) : x + 2]) < 1.0:
35
            image[max(0, y) : y + 1, max(0, x) : x + 1] = color
36
    return image