Switch to unified view

a b/Section 3 Simulate DIMSE/src/utils/utils.py
1
"""
2
Various utility methods in this module
3
"""
4
import numpy as np
5
import matplotlib.pyplot as plt
6
import matplotlib as mpl
7
import torch
8
from PIL import Image
9
10
# Tell Matplotlib to not try and use interactive backend
11
mpl.use("agg")
12
13
def mpl_image_grid(images):
14
    """
15
    Create an image grid from an array of images. Show up to 16 images in one figure
16
17
    Arguments:
18
        image {Torch tensor} -- NxWxH array of images
19
20
    Returns:
21
        Matplotlib figure
22
    """
23
    # Create a figure to contain the plot.
24
    n = min(images.shape[0], 16) # no more than 16 thumbnails
25
    rows = 4
26
    cols = (n // 4) + (1 if (n % 4) != 0 else 0)
27
    figure = plt.figure(figsize=(2*rows, 2*cols))
28
    plt.subplots_adjust(0, 0, 1, 1, 0.001, 0.001)
29
    for i in range(n):
30
        # Start next subplot.
31
        plt.subplot(cols, rows, i + 1)
32
        plt.xticks([])
33
        plt.yticks([])
34
        plt.grid(False)
35
        if images.shape[1] == 3:
36
            # this is specifically for 3 softmax'd classes with 0 being bg
37
            # We are building a probability map from our three classes using 
38
            # fractional probabilities contained in the mask
39
            vol = images[i].detach().numpy()
40
            img = [[[(1-vol[0,x,y])*vol[1,x,y], (1-vol[0,x,y])*vol[2,x,y], 0] \
41
                            for y in range(vol.shape[2])] \
42
                            for x in range(vol.shape[1])]
43
            plt.imshow(img)
44
        else: # plotting only 1st channel
45
            plt.imshow((images[i, 0]*255).int(), cmap= "gray")
46
47
    return figure
48
49
def log_to_tensorboard(writer, loss, data, target, prediction_softmax, prediction, counter):
50
    """Logs data to Tensorboard
51
52
    Arguments:
53
        writer {SummaryWriter} -- PyTorch Tensorboard wrapper to use for logging
54
        loss {float} -- loss
55
        data {tensor} -- image data
56
        target {tensor} -- ground truth label
57
        prediction_softmax {tensor} -- softmax'd prediction
58
        prediction {tensor} -- raw prediction (to be used in argmax)
59
        counter {int} -- batch and epoch counter
60
    """
61
    writer.add_scalar("Loss",\
62
                    loss, counter)
63
    writer.add_figure("Image Data",\
64
        mpl_image_grid(data.float().cpu()), global_step=counter)
65
    writer.add_figure("Mask",\
66
        mpl_image_grid(target.float().cpu()), global_step=counter)
67
    writer.add_figure("Probability map",\
68
        mpl_image_grid(prediction_softmax.cpu()), global_step=counter)
69
    writer.add_figure("Prediction",\
70
        mpl_image_grid(torch.argmax(prediction.cpu(), dim=1, keepdim=True)), global_step=counter)
71
72
def save_numpy_as_image(arr, path):
73
    """
74
    This saves image (2D array) as a file using matplotlib
75
76
    Arguments:
77
        arr {array} -- 2D array of pixels
78
        path {string} -- path to file
79
    """
80
    plt.imshow(arr, cmap="gray") #Needs to be in row,col order
81
    plt.savefig(path)
82
83
def med_reshape(image, new_shape):
84
    """
85
    This function reshapes 3D data to new dimension padding with zeros
86
    and leaving the content in the top-left corner
87
88
    Arguments:
89
        image {array} -- 3D array of pixel data
90
        new_shape {3-tuple} -- expected output shape
91
92
    Returns:
93
        3D array of desired shape, padded with zeroes
94
    """
95
96
    reshaped_image = np.zeros(new_shape)
97
98
    # TASK: write your original image into the reshaped image
99
    reshaped_image[0:image.shape[0],0:image.shape[1],0:image.shape[2]] += image 
100
101
    return reshaped_image