# paper/Benchmarking with real and synthetic datasets/plot_benchmark.py
1
import pandas as pd
2
import numpy as np
3
import matplotlib.pyplot as plt
4
import seaborn as sns
5
6
def _load_vitae_results(path, label):
    """Load one VITAE result table and keep only its ``modified_map`` rows.

    Parameters
    ----------
    path : str
        CSV file to read; its first column becomes the index.
    label : str
        Method name written into the ``method`` column of every kept row, so
        the VITAE variants stay distinguishable after concatenation.

    Returns
    -------
    pandas.DataFrame
        The filtered rows with ``method`` replaced by *label*.
    """
    results = pd.read_csv(path, index_col=0)
    # ``assign`` returns a new frame instead of writing into the
    # boolean-filtered slice, which pandas flags with SettingWithCopyWarning.
    return results[results['method'] == 'modified_map'].assign(method=label)


# VITAE results under the two likelihoods, plus the competing methods.
dat_NB = _load_vitae_results('result/result_VITAE_NB.csv', 'VITAE_NB')
dat_Gaussian = _load_vitae_results('result/result_VITAE_Gaussian.csv', 'VITAE_Gauss')
dat_other = pd.read_csv('result/result_other_methods.csv', index_col=0)

# Stack everything into a single table for plotting.
dat = pd.concat([dat_NB, dat_Gaussian, dat_other])
20
21
# Heatmap grid rows: the value of the ``source`` column shown in each row.
sources = [
    'dyngen',
    'our model',
    'real',
]

# Heatmap grid columns: the score column of ``dat`` shown in each column.
scores = [
    'GED score',
    'IM score',
    'ARI',
    'GRI',
    'PDT score',
]

# One colormap per score, matched to ``scores`` by position.
cmaps = [
    'YlOrRd_r',
    'YlGn_r',
    'RdPu_r',
    'PuBu_r',
    'BuGn_r',
]

# Rotation (degrees) for the method-name tick labels on the bottom row.
rotation_xticklabels = 25
26
27
# Draw a grid of heatmaps: one row per data source, one column per score.
# Each panel shows datasets (rows) x methods (columns) for that source/score.
# The figure is saved to result/comp_heatmap.pdf.
sns.set(font_scale=1.5,
        rc={'axes.facecolor': (0.85, 0.85, 0.85),
            'xtick.labelsize': 14,
            'ytick.labelsize': 11})

n_rows, n_cols = len(sources), len(scores)
# NOTE(review): needs exactly one entry per element of ``sources``; the values
# presumably reflect how many datasets each source has — TODO confirm.
height_ratios = [6, 2, 8]
# squeeze=False keeps ``ax`` 2-D even if there is a single source or score.
fig, ax = plt.subplots(n_rows, n_cols,
                       gridspec_kw={'height_ratios': height_ratios},
                       figsize=(20, 10), squeeze=False)
last_row, last_col = n_rows - 1, n_cols - 1

for i, score in enumerate(scores):
    # One colour scale per score, shared across all sources, so panels in the
    # same column are directly comparable. Computed once per score.
    vmin = dat[score].min()
    vmax = dat[score].max()
    for j, source in enumerate(sources):
        panel = ax[j, i]
        # Datasets as rows, methods as columns. ``pivot`` arguments are
        # keyword-only since pandas 2.0; positional calls raise TypeError.
        table = dat[dat.source == source].pivot(index='data', columns='method',
                                                values=score)
        is_bottom = j == last_row
        is_left = i == 0
        if j == 0:
            panel.set_title(score, fontweight='bold', fontsize=18)
        # Only the bottom row shows method names and a horizontal colourbar;
        # only the left column shows dataset names. 'auto' is seaborn's default.
        sns.heatmap(table, ax=panel,
                    xticklabels='auto' if is_bottom else False,
                    yticklabels='auto' if is_left else False,
                    cbar=is_bottom,
                    cbar_kws={"orientation": "horizontal", "pad": 0.2} if is_bottom else None,
                    vmin=vmin, vmax=vmax, cmap=cmaps[i])
        if is_bottom:
            panel.set_xticklabels(panel.get_xticklabels(),
                                  rotation=rotation_xticklabels, ha="center")
        if i == last_col:
            # Label each row with its source name on the right-hand side.
            panel.set_ylabel(source, rotation=270, fontweight='bold',
                             fontsize=18, labelpad=20)
            panel.yaxis.set_label_position("right")
        else:
            panel.set_ylabel(None)
        panel.set_xlabel(None)

fig.tight_layout()
# Override tight_layout's spacing between panels.
fig.subplots_adjust(wspace=0.2, hspace=0.1)
fig.savefig('result/comp_heatmap.pdf', bbox_inches='tight')
plt.close(fig)