Diff of /simdeep/plot_utils.py [000000] .. [53737a]

Switch to unified view

a b/simdeep/plot_utils.py
1
from sklearn.decomposition import PCA
2
3
from colour import Color
4
import numpy as np
5
6
import matplotlib
7
matplotlib.use('Agg')
8
9
import seaborn as sns
10
11
import pylab as plt
12
import mpld3
13
14
sns.set(color_codes=True)
15
16
17
CSS = """
18
table
19
{
20
  border-collapse: collapse;
21
}
22
th
23
{
24
  color: #ffffff;
25
  background-color: #000000;
26
}
27
td
28
{
29
  background-color: #cccccc;
30
}
31
table, th, td
32
{
33
  font-family:Arial, Helvetica, sans-serif;
34
  border: 1px solid black;
35
  text-align: right;
36
}
37
"""
38
39
40
class SampleHTML():
41
    def __init__(self, name, label, proba, survival):
42
        """
43
        """
44
        try:
45
            nbdays, isdead = survival
46
        except Exception:
47
            nbdays, isdead = 'NaN', 'NaN'
48
49
        self.html =  """
50
<table border="1" class="dataframe">
51
  <thead>
52
    <tr style="text-align: right;">
53
      <th></th>
54
      <th>{0}</th>
55
    </tr>
56
  </thead>
57
  <tbody>
58
    <tr>
59
      <th>Assigned class</th>
60
      <td>{1}</td>
61
    </tr>
62
    <tr>
63
      <th>class probability</th>
64
      <td>{2}</td>
65
    </tr>
66
    <tr>
67
      <th>nb days followed</th>
68
      <td>{3}</td>
69
    </tr>
70
    <tr>
71
      <th>Event</th>
72
      <td>{4}</td>
73
    </tr>
74
  </tbody>
75
</table>
76
                """.format(name, label, proba, nbdays, isdead)
77
78
79
def make_color_dict_from_r(labels):
80
    """ """
81
    labels_set = set(labels)
82
83
    cin = Color('red')
84
    cout = Color('#56f442')
85
86
    gradient = list(map(lambda x:x.get_rgb(),
87
                   cin.range_to(cout, len(labels_set))))
88
89
    len_color = len(gradient)
90
91
    if len_color > 2:
92
        gradient[1] = Color('green').get_rgb()
93
        gradient[2] = Color('blue').get_rgb()
94
95
    if len_color > 3:
96
        gradient[3] = Color('cyan').get_rgb()
97
98
    if len_color > 4:
99
        gradient[4] = Color('magenta').get_rgb()
100
101
    if len_color > 5:
102
        gradient[5] = Color('yellow').get_rgb()
103
104
    return dict(zip(labels_set, gradient))
105
106
107
def make_color_list(id_list):
108
    """
109
    According to an id_list define a color gradient
110
    return {id:color}
111
    """
112
    try:
113
        assert([Color(idc) for idc in id_list])
114
    except Exception:
115
        pass
116
    else:
117
        return id_list
118
119
    color_dict = make_color_dict(id_list)
120
121
    return np.array([color_dict[label] for label in id_list])
122
123
def make_color_dict(id_list):
124
    """
125
    According to an id_list define a color gradient
126
    return {id:color}
127
    """
128
    id_list = list(set(id_list))
129
130
    first_c = Color("red")
131
    middle_c = Color("green")
132
133
    m_length1 = len(id_list)
134
135
    gradient = list(first_c.range_to(middle_c, m_length1))
136
137
    color_dict =  {id_list[i]: gradient[i].get_hex_l()
138
                   for i in range(len(id_list))}
139
140
    return color_dict
141
142
def plot_kernel_plots(
143
        test_labels,
144
        test_labels_proba,
145
        labels,
146
        activities,
147
        activities_test,
148
        dataset,
149
        path_html,
150
        metadata_frame=None):
151
    """
152
    perform a html kernel plot
153
    """
154
    fig, ax = plt.subplots(figsize=(7, 7))
155
156
    color_dict = make_color_dict_from_r(labels)
157
    labels_c_test = np.array([color_dict[label] for label in test_labels])
158
159
    decomp = PCA(n_components=2)
160
    X, Y = decomp.fit_transform(activities).T
161
162
    X_test, Y_test = decomp.transform(activities_test).T
163
164
    for label in set(labels):
165
        ax.scatter(
166
            X_test[test_labels == label],
167
            Y_test[test_labels == label],
168
            s=40,
169
            # linewidths=2.0,
170
            alpha=1.0,
171
           # marker='square_cross',
172
            edgecolors='k',
173
            zorder=2,
174
           color=labels_c_test[test_labels == label],
175
           label='test cluster nb {0}'.format(label))
176
177
        sns.kdeplot(
178
            X[labels == label],
179
            Y[labels == label],
180
            shade=True,
181
            cmap=sns.dark_palette(color_dict[label], as_cmap=True),
182
            color=color_dict[label],
183
            ax=ax,
184
            label='cluster nb {0}'.format(label),
185
            zorder=1,
186
            thresh=False,
187
            alpha=0.7
188
        )
189
190
    survival_test = np.nan_to_num(dataset.survival_test)
191
192
    labels = [SampleHTML(
193
        name=dataset.sample_ids_test[i],
194
        label=test_labels[i],
195
        survival=np.asarray(survival_test[i])[0],
196
        proba=test_labels_proba[i][test_labels[i]]).html
197
              for i in range(len(test_labels))]
198
199
    scatter = ax.plot(X_test, Y_test, 'o', color='b', mec='k',
200
                      ms=15, mew=1, alpha=0.0, zorder=3,)[0]
201
202
    tooltip = mpld3.plugins.PointHTMLTooltip(
203
        scatter, labels, voffset=10, hoffset=10, css=CSS)
204
    mpld3.plugins.connect(fig, tooltip)
205
206
    mpld3.save_html(fig, path_html)
207
208
    print('kde plot saved at:{0}'.format(path_html))