|
a |
|
b/paper/Benchmarking with real and synthetic datasets/plot_benchmark.py |
|
|
1 |
import pandas as pd |
|
|
2 |
import numpy as np |
|
|
3 |
import matplotlib.pyplot as plt |
|
|
4 |
import seaborn as sns |
|
|
5 |
|
|
|
6 |
# Load benchmark results. From each VITAE result file only the
# 'modified_map' rows are kept, and they are relabelled so the two VITAE
# likelihood variants appear as separate methods next to the other methods.


def _load_vitae_results(path, label):
    """Return the 'modified_map' rows of a VITAE result CSV, relabelled.

    Parameters
    ----------
    path : str
        CSV file to read; its first column is used as the index.
    label : str
        Value written into the ``method`` column of every kept row.

    Returns
    -------
    pandas.DataFrame
        A new frame (not a view of the file contents) containing only the
        rows whose ``method`` was 'modified_map', with ``method`` set to
        ``label``.
    """
    results = pd.read_csv(path, index_col=0)
    kept = results[results['method'] == 'modified_map']
    # ``assign`` returns a fresh frame, so we never write into a filtered
    # slice (which is what triggers pandas' SettingWithCopyWarning).
    return kept.assign(method=label)


dat_NB = _load_vitae_results('result/result_VITAE_NB.csv', 'VITAE_NB')
dat_Gaussian = _load_vitae_results('result/result_VITAE_Gaussian.csv', 'VITAE_Gauss')
dat_other = pd.read_csv('result/result_other_methods.csv', index_col=0)

# All methods stacked into one long table; the plotting code below selects
# rows by 'source' and pivots on 'data' x 'method'.
dat = pd.concat([dat_NB, dat_Gaussian, dat_other])
|
|
20 |
|
|
|
21 |
# Figure layout configuration.

# Data sources, one heatmap row each (top to bottom).
sources = ['dyngen', 'our model', 'real']

# Evaluation metrics, one heatmap column each (left to right), kept next to
# the colormap used for that column so the two lists cannot drift apart.
_score_cmap_pairs = [
    ('GED score', 'YlOrRd_r'),
    ('IM score', 'YlGn_r'),
    ('ARI', 'RdPu_r'),
    ('GRI', 'PuBu_r'),
    ('PDT score', 'BuGn_r'),
]
scores = [name for name, _ in _score_cmap_pairs]
cmaps = [cmap for _, cmap in _score_cmap_pairs]

# Rotation (degrees) applied to the x tick labels of the bottom heatmap row.
rotation_xticklabels = 25
|
|
26 |
|
|
|
27 |
# Draw a grid of heatmaps: one row per data source, one column per score.
# Each column uses a single colour scale (min/max of that score over all
# rows of `dat`). Only the bottom row shows x tick labels and a horizontal
# colorbar, only the left column shows y tick labels, and the rightmost
# column carries the source name as a right-hand row label.
sns.set(font_scale=1.5, rc={'axes.facecolor': (0.85, 0.85, 0.85),
                            'xtick.labelsize': 14, 'ytick.labelsize': 11})
fig, ax = plt.subplots(
    len(sources), len(scores),
    # NOTE: height_ratios has one entry per source row; update it if the
    # number of sources changes.
    gridspec_kw={'height_ratios': [6, 2, 8],
                 'width_ratios': [1] * len(scores)},
    figsize=(20, 10),
)

last_row = len(sources) - 1
last_col = len(scores) - 1
for i, score in enumerate(scores):
    # Shared colour limits for the whole column. They do not depend on the
    # row, so they are computed once per score instead of once per cell.
    vmin = dat[score].min()
    vmax = dat[score].max()
    for j, source in enumerate(sources):
        cell = ax[j][i]
        is_bottom = j == last_row
        is_left = i == 0

        # Keyword arguments are required: DataFrame.pivot is keyword-only
        # from pandas 2.0, where the positional form raises TypeError.
        dat_t = dat[dat['source'] == source].pivot(
            index='data', columns='method', values=score)

        if j == 0:
            cell.set_title(score, fontweight='bold', fontsize=18)

        heatmap_kws = {'vmin': vmin, 'vmax': vmax, 'cmap': cmaps[i],
                       'cbar': is_bottom}
        if is_bottom:
            # Fresh dict per call so seaborn never sees a shared mutable.
            heatmap_kws['cbar_kws'] = {"orientation": "horizontal", "pad": 0.2}
        else:
            heatmap_kws['xticklabels'] = False
        if not is_left:
            heatmap_kws['yticklabels'] = False
        sns.heatmap(dat_t, ax=cell, **heatmap_kws)

        if is_bottom:
            cell.set_xticklabels(cell.get_xticklabels(),
                                 rotation=rotation_xticklabels, ha="center")

        # Replace the axis labels set by the heatmap: the source name goes
        # on the right of the last column, everything else is cleared.
        if i == last_col:
            cell.set_ylabel(source, rotation=270, fontweight='bold',
                            fontsize=18, labelpad=20)
            cell.yaxis.set_label_position("right")
        else:
            cell.set_ylabel(None)
        cell.set_xlabel(None)

plt.tight_layout()
plt.subplots_adjust(wspace=0.2, hspace=0.1)
fig.savefig('result/comp_heatmap.pdf', bbox_inches='tight')