--- a +++ b/common/utils.py @@ -0,0 +1,36 @@ +import numpy as np + + + +def log_images (x, y_pred, y_true=None, channel=1): + images = [] + x_np = x[:, channel].cpu().numpy() + y_true_np = y_true[:, 0].cpu().numpy() + y_pred_np = y_pred[:, 0].cpu().numpy() + for i in range(x_np.shape[0]): + image = gray2rgb(np.squeeze(x_np[i])) + image = outline(image, y_pred_np[i], color=[255, 0, 0]) + image = outline(image, y_true_np[i], color=[0, 255, 0]) + images.append(image) + return images + + + +def gray2rgb(image): + w, h = image.shape + image += np.abs(np.min(image)) + image_max = np.abs(np.max(image)) + if image_max > 0: + image /= image_max + ret = np.empty((w, h, 3), dtype=np.uint8) + ret[:, :, 2] = ret[:, :, 1] = ret[:, :, 0] = image * 255 + return ret + + +def outline(image, mask, color): + mask = np.round(mask) + yy, xx = np.nonzero(mask) + for y, x in zip(yy, xx): + if 0.0 < np.mean(mask[max(0, y - 1) : y + 2, max(0, x - 1) : x + 2]) < 1.0: + image[max(0, y) : y + 1, max(0, x) : x + 1] = color + return image