Switch to unified view

a b/(1) PyTorch_HistoNet/util/CMC.py
1
import os
2
import matplotlib.pyplot as plt
3
import matplotlib.lines as mlines
4
5
6
default_color = ['r','g','b','c','m','y','orange','brown']
7
default_marker = ['*','o','s','v','X','*','.','P']
8
default_fontsize = 48
9
10
class CMC:
11
    def __init__(self,cmc_dict, color=default_color, marker=default_marker, fontsize=default_fontsize):
12
        self.color = color
13
        self.marker = marker
14
        self.cmc_dict = cmc_dict
15
        self.fontsize = fontsize
16
        self.font = {'family': 'DejaVu Sans', 'weight': 'normal', 'size': self.fontsize}
17
        self.sizeX = 20
18
        self.sizeY = 10
19
        self.lw = 10
20
        self.ms = 40
21
22
    def plot(self, title, rank=20, xlabel='Rank', ylabel='Matching Rates (%)', show_grid=True):
23
        fig, ax = plt.subplots(figsize=(self.sizeX, self.sizeY))
24
        fig.suptitle(title)
25
        x = list(range(0, rank+1, 5))
26
        plt.ylim(0.8, 1.0)
27
        plt.xlim(1, rank)
28
        plt.xlabel(xlabel)
29
        plt.ylabel(ylabel)
30
        plt.xticks(x)
31
        plt.grid(show_grid)
32
33
        method_name = []
34
        i = 0
35
        for name in self.cmc_dict.keys():
36
            if rank < len(self.cmc_dict[name]):
37
                temp_cmc = self.cmc_dict[name][:rank]
38
                r = list(range(1, rank+1))
39
            else:
40
                temp_cmc = self.cmc_dict[name]
41
                r = list(range(1, len(temp_cmc)+1))
42
43
            if name == list(self.cmc_dict.keys())[-1]:
44
                #globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[0], marker=self.marker[0], label='{:.1f}% {}'.format(self.cmc_dict[name][0]*100, name))
45
                globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[0], marker=self.marker[0], label='{}'.format(name), linewidth=self.lw, markersize=self.ms)
46
            else:
47
                #globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[i+1], marker=self.marker[i+1], label='{:.1f}% {}'.format(self.cmc_dict[name][0]*100, name))
48
                globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[i+1], marker=self.marker[i+1], label='{}'.format(name), linewidth=self.lw, markersize=self.ms)
49
                i = i+1
50
            ax.add_line(globals()[name])
51
            method_name.append(globals()[name])
52
53
        plt.legend(handles=method_name)
54
55
        plt.rc('xtick', labelsize=self.fontsize)
56
        plt.rc('ytick', labelsize=self.fontsize)
57
        plt.rc('font', **self.font)
58
59
        plt.show()
60
61
    def save(self, title, filename,
62
             rank=20, xlabel='Rank',
63
             ylabel='Matching Rates (%)', show_grid=True,
64
             save_path=os.getcwd(), format='png', **kwargs):
65
        fig, ax = plt.subplots(figsize=(self.sizeX, self.sizeY))
66
        fig.suptitle(title)
67
        x = list(range(0, rank+1, 5))
68
        plt.ylim(0.8, 1.0)
69
        plt.xlim(1, rank)
70
        plt.xlabel(xlabel)
71
        plt.ylabel(ylabel)
72
        plt.xticks(x)
73
        plt.grid(show_grid)
74
75
        method_name = []
76
        i = 0
77
        for name in self.cmc_dict.keys():
78
            if rank < len(self.cmc_dict[name]):
79
                temp_cmc = self.cmc_dict[name][:rank]
80
                r = list(range(1, rank+1))
81
            else:
82
                temp_cmc = self.cmc_dict[name]
83
                r = list(range(1, len(temp_cmc)+1))
84
85
            if name == list(self.cmc_dict.keys())[-1]:
86
                #globals()[name] = mlines.Line2D(r, temp_cmc, color='r', marker='*', label='{:.1f}% {}'.format(self.cmc_dict[name][0]*100, name))
87
                globals()[name] = mlines.Line2D(r, temp_cmc, color='r', marker='*', label='{}'.format(name), linewidth=self.lw, markersize=self.ms)
88
            else:
89
                #globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[i], marker=self.marker[i], label='{:.1f}% {}'.format(self.cmc_dict[name][0]*100, name))
90
                globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[i], marker=self.marker[i], label='{}'.format(name), linewidth=self.lw, markersize=self.ms)
91
                i = i+1
92
            ax.add_line(globals()[name])
93
            method_name.append(globals()[name])
94
95
        plt.legend(handles=method_name)
96
97
        plt.rc('xtick', labelsize=self.fontsize)
98
        plt.rc('ytick', labelsize=self.fontsize)
99
        plt.rc('font', **self.font)
100
101
        fig.savefig(os.path.join(save_path, filename+'.'+format),
102
                    format=format,
103
                    bbox_inches='tight',
104
                    pad_inches = 0, **kwargs)