Diff of /simdeep/plot_utils.py [000000] .. [53737a]

Switch to side-by-side view

--- a
+++ b/simdeep/plot_utils.py
@@ -0,0 +1,208 @@
+from sklearn.decomposition import PCA
+
+from colour import Color
+import numpy as np
+
+import matplotlib
+matplotlib.use('Agg')
+
+import seaborn as sns
+
+import pylab as plt
+import mpld3
+
+sns.set(color_codes=True)
+
+
+CSS = """
+table
+{
+  border-collapse: collapse;
+}
+th
+{
+  color: #ffffff;
+  background-color: #000000;
+}
+td
+{
+  background-color: #cccccc;
+}
+table, th, td
+{
+  font-family:Arial, Helvetica, sans-serif;
+  border: 1px solid black;
+  text-align: right;
+}
+"""
+
+
+class SampleHTML():
+    def __init__(self, name, label, proba, survival):
+        """
+        """
+        try:
+            nbdays, isdead = survival
+        except Exception:
+            nbdays, isdead = 'NaN', 'NaN'
+
+        self.html =  """
+<table border="1" class="dataframe">
+  <thead>
+    <tr style="text-align: right;">
+      <th></th>
+      <th>{0}</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <th>Assigned class</th>
+      <td>{1}</td>
+    </tr>
+    <tr>
+      <th>class probability</th>
+      <td>{2}</td>
+    </tr>
+    <tr>
+      <th>nb days followed</th>
+      <td>{3}</td>
+    </tr>
+    <tr>
+      <th>Event</th>
+      <td>{4}</td>
+    </tr>
+  </tbody>
+</table>
+                """.format(name, label, proba, nbdays, isdead)
+
+
+def make_color_dict_from_r(labels):
+    """ """
+    labels_set = set(labels)
+
+    cin = Color('red')
+    cout = Color('#56f442')
+
+    gradient = list(map(lambda x:x.get_rgb(),
+                   cin.range_to(cout, len(labels_set))))
+
+    len_color = len(gradient)
+
+    if len_color > 2:
+        gradient[1] = Color('green').get_rgb()
+        gradient[2] = Color('blue').get_rgb()
+
+    if len_color > 3:
+        gradient[3] = Color('cyan').get_rgb()
+
+    if len_color > 4:
+        gradient[4] = Color('magenta').get_rgb()
+
+    if len_color > 5:
+        gradient[5] = Color('yellow').get_rgb()
+
+    return dict(zip(labels_set, gradient))
+
+
+def make_color_list(id_list):
+    """
+    According to an id_list define a color gradient
+    return {id:color}
+    """
+    try:
+        assert([Color(idc) for idc in id_list])
+    except Exception:
+        pass
+    else:
+        return id_list
+
+    color_dict = make_color_dict(id_list)
+
+    return np.array([color_dict[label] for label in id_list])
+
+def make_color_dict(id_list):
+    """
+    According to an id_list define a color gradient
+    return {id:color}
+    """
+    id_list = list(set(id_list))
+
+    first_c = Color("red")
+    middle_c = Color("green")
+
+    m_length1 = len(id_list)
+
+    gradient = list(first_c.range_to(middle_c, m_length1))
+
+    color_dict =  {id_list[i]: gradient[i].get_hex_l()
+                   for i in range(len(id_list))}
+
+    return color_dict
+
+def plot_kernel_plots(
+        test_labels,
+        test_labels_proba,
+        labels,
+        activities,
+        activities_test,
+        dataset,
+        path_html,
+        metadata_frame=None):
+    """
+    perform a html kernel plot
+    """
+    fig, ax = plt.subplots(figsize=(7, 7))
+
+    color_dict = make_color_dict_from_r(labels)
+    labels_c_test = np.array([color_dict[label] for label in test_labels])
+
+    decomp = PCA(n_components=2)
+    X, Y = decomp.fit_transform(activities).T
+
+    X_test, Y_test = decomp.transform(activities_test).T
+
+    for label in set(labels):
+        ax.scatter(
+            X_test[test_labels == label],
+            Y_test[test_labels == label],
+            s=40,
+            # linewidths=2.0,
+            alpha=1.0,
+           # marker='square_cross',
+            edgecolors='k',
+            zorder=2,
+           color=labels_c_test[test_labels == label],
+           label='test cluster nb {0}'.format(label))
+
+        sns.kdeplot(
+            X[labels == label],
+            Y[labels == label],
+            shade=True,
+            cmap=sns.dark_palette(color_dict[label], as_cmap=True),
+            color=color_dict[label],
+            ax=ax,
+            label='cluster nb {0}'.format(label),
+            zorder=1,
+            thresh=False,
+            alpha=0.7
+        )
+
+    survival_test = np.nan_to_num(dataset.survival_test)
+
+    labels = [SampleHTML(
+        name=dataset.sample_ids_test[i],
+        label=test_labels[i],
+        survival=np.asarray(survival_test[i])[0],
+        proba=test_labels_proba[i][test_labels[i]]).html
+              for i in range(len(test_labels))]
+
+    scatter = ax.plot(X_test, Y_test, 'o', color='b', mec='k',
+                      ms=15, mew=1, alpha=0.0, zorder=3,)[0]
+
+    tooltip = mpld3.plugins.PointHTMLTooltip(
+        scatter, labels, voffset=10, hoffset=10, css=CSS)
+    mpld3.plugins.connect(fig, tooltip)
+
+    mpld3.save_html(fig, path_html)
+
+    print('kde plot saved at:{0}'.format(path_html))