a b/tool/Code/utilities/visualization_misc.py
1
# Copyright 2019 Population Health Sciences and Image Analysis, German Center for Neurodegenerative Diseases(DZNE)
2
#
3
#    Licensed under the Apache License, Version 2.0 (the "License");
4
#    you may not use this file except in compliance with the License.
5
#    You may obtain a copy of the License at
6
#
7
#        http://www.apache.org/licenses/LICENSE-2.0
8
#
9
#    Unless required by applicable law or agreed to in writing, software
10
#    distributed under the License is distributed on an "AS IS" BASIS,
11
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
#    See the License for the specific language governing permissions and
13
#    limitations under the License.
14
15
import matplotlib
16
matplotlib.use('agg')
17
import matplotlib.pyplot as plt
18
import numpy as np
19
import matplotlib.gridspec as gridspec
20
import matplotlib.cm as cm
21
from matplotlib.colors import LinearSegmentedColormap
22
import itertools
23
24
def get_colors(inp, colormap, vmin=None, vmax=None):
    """Map values to colors by normalizing them and applying a colormap.

    Args:
        inp: value(s) to convert into colors
        colormap: callable colormap (e.g. ``plt.cm.jet``) applied to the
            normalized values
        vmin: lower bound passed to ``plt.Normalize`` (None is forwarded as is)
        vmax: upper bound passed to ``plt.Normalize`` (None is forwarded as is)

    Returns:
        Whatever ``colormap`` returns for the normalized input
    """
    scaler = plt.Normalize(vmin, vmax)
    normalized = scaler(inp)
    return colormap(normalized)
29
30
31
def _pad_columns(plane, width, offset):
    """Copy a 2D plane onto a zero canvas `width` columns wide, starting at column `offset`."""
    canvas = np.zeros((plane.shape[0], width))
    canvas[:, offset:offset + plane.shape[1]] = plane
    return canvas


def _slice_view(volume, view, control_point, width, offset):
    """Return the 2D slice of `volume` through `control_point` for view 0, 1 or 2."""
    if view == 0:
        return volume[control_point[0], :, :]
    if view == 1:
        return volume[:, control_point[1], :]
    # Third view: slice along the last axis and pad the columns so the panel
    # is as wide as the other two views.
    return _pad_columns(volume[:, :, control_point[2]], width, offset)


def _draw_panel(ax, image, marker_xy, aspect, show_marker, overlay=None, **overlay_kwargs):
    """Draw one gray-scale panel, an optional overlay and the optional control-point marker."""
    ax.imshow(image, cmap=cm.gray, aspect=aspect)
    if overlay is not None:
        ax.imshow(overlay, aspect=aspect, **overlay_kwargs)
    if show_marker:
        ax.scatter(x=marker_xy[0], y=marker_xy[1], c='r', s=2)
    ax.set_xticklabels([])
    ax.set_yticklabels([])


def multiview_plotting(data, labels, control_point, savepath, classes=5, alpha=0.5, nbviews=3,
                       plot_labels=True, plot_control_point=True):
    """Plot data and labels in up to three views and save the figure.

    The first row shows the gray-scale slices through ``control_point``.
    When ``plot_labels`` is True a second row shows the same slices with the
    labels drawn on top.

    Args:
        data: Original 3D volume
        labels: Original labels for the 3D volume (only read when plot_labels is True)
        control_point: index triplet of the point where the views are sliced
        savepath: path where the image is going to be saved
        classes: number of classes in the labels
        alpha: transparency of the labels on top of the original data
        nbviews: 1 slices axis 0 only, 2 adds the slice along axis 1,
            3 adds the slice along axis 2
        plot_labels: True plot data and labels, False only plot data
        plot_control_point: True marks the control point with a red dot in every panel

    Returns:
        None. The figure is written to ``savepath`` and then closed.
    """
    # Create the colormap for the labels; the first class (index 0) is drawn in black.
    colors = get_colors(np.arange(classes), plt.cm.jet)
    colors[0, 0:3] = [0, 0, 0]
    my_cm = LinearSegmentedColormap.from_list('mylist', colors, classes)

    plt.ioff()
    nb_rows = 2 if plot_labels else 1
    fig = plt.gcf()
    outer_grid = gridspec.GridSpec(nb_rows, nbviews, wspace=0.05, hspace=0.0005)

    # Geometry of the third view: the (axis 0, axis 1) plane is centred on a
    # canvas at least as wide as axis 2. Slicing with an explicit length
    # keeps the assignment valid for odd size differences and when axis 1 is
    # larger than axis 2.
    canvas_width = max(data.shape[1], data.shape[2])
    offset = (canvas_width - data.shape[1]) // 2
    stretched = data.shape[1] / data.shape[0]

    # Per view: (imshow aspect, (x, y) position of the control point).
    # aspect=None is imshow's own default, used for the first view.
    view_specs = [
        (None, (control_point[2], control_point[1])),
        (stretched, (control_point[2], control_point[0])),
        # Columns of the third view are axis 1 shifted by the padding offset.
        (stretched, (control_point[1] + offset, control_point[0])),
    ]

    for i in range(min(nbviews, len(view_specs))):
        aspect, marker_xy = view_specs[i]
        image = _slice_view(data, i, control_point, canvas_width, offset)

        ax = plt.subplot(outer_grid[i])
        _draw_panel(ax, image, marker_xy, aspect, plot_control_point)

        if plot_labels:
            label_image = _slice_view(labels, i, control_point, canvas_width, offset)
            ax = plt.subplot(outer_grid[i + nbviews])
            _draw_panel(ax, image, marker_xy, aspect, plot_control_point,
                        overlay=label_image, vmin=0, vmax=classes, cmap=my_cm, alpha=alpha)

    # Remove every margin, axis and tick so the panels fill the whole figure.
    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
    for ax in fig.axes:
        ax.axis('off')
        ax.margins(0, 0)
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())

    plt.savefig(savepath, transparent=True, bbox_inches='tight', pad_inches=0, dpi=300)
    plt.close(fig)
130
131