|
a |
|
b/(1) PyTorch_HistoNet/util/CMC.py |
|
|
1 |
import os |
|
|
2 |
import matplotlib.pyplot as plt |
|
|
3 |
import matplotlib.lines as mlines |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
default_color = ['r','g','b','c','m','y','orange','brown'] |
|
|
7 |
default_marker = ['*','o','s','v','X','*','.','P'] |
|
|
8 |
default_fontsize = 48 |
|
|
9 |
|
|
|
10 |
class CMC: |
|
|
11 |
def __init__(self,cmc_dict, color=default_color, marker=default_marker, fontsize=default_fontsize): |
|
|
12 |
self.color = color |
|
|
13 |
self.marker = marker |
|
|
14 |
self.cmc_dict = cmc_dict |
|
|
15 |
self.fontsize = fontsize |
|
|
16 |
self.font = {'family': 'DejaVu Sans', 'weight': 'normal', 'size': self.fontsize} |
|
|
17 |
self.sizeX = 20 |
|
|
18 |
self.sizeY = 10 |
|
|
19 |
self.lw = 10 |
|
|
20 |
self.ms = 40 |
|
|
21 |
|
|
|
22 |
def plot(self, title, rank=20, xlabel='Rank', ylabel='Matching Rates (%)', show_grid=True): |
|
|
23 |
fig, ax = plt.subplots(figsize=(self.sizeX, self.sizeY)) |
|
|
24 |
fig.suptitle(title) |
|
|
25 |
x = list(range(0, rank+1, 5)) |
|
|
26 |
plt.ylim(0.8, 1.0) |
|
|
27 |
plt.xlim(1, rank) |
|
|
28 |
plt.xlabel(xlabel) |
|
|
29 |
plt.ylabel(ylabel) |
|
|
30 |
plt.xticks(x) |
|
|
31 |
plt.grid(show_grid) |
|
|
32 |
|
|
|
33 |
method_name = [] |
|
|
34 |
i = 0 |
|
|
35 |
for name in self.cmc_dict.keys(): |
|
|
36 |
if rank < len(self.cmc_dict[name]): |
|
|
37 |
temp_cmc = self.cmc_dict[name][:rank] |
|
|
38 |
r = list(range(1, rank+1)) |
|
|
39 |
else: |
|
|
40 |
temp_cmc = self.cmc_dict[name] |
|
|
41 |
r = list(range(1, len(temp_cmc)+1)) |
|
|
42 |
|
|
|
43 |
if name == list(self.cmc_dict.keys())[-1]: |
|
|
44 |
#globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[0], marker=self.marker[0], label='{:.1f}% {}'.format(self.cmc_dict[name][0]*100, name)) |
|
|
45 |
globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[0], marker=self.marker[0], label='{}'.format(name), linewidth=self.lw, markersize=self.ms) |
|
|
46 |
else: |
|
|
47 |
#globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[i+1], marker=self.marker[i+1], label='{:.1f}% {}'.format(self.cmc_dict[name][0]*100, name)) |
|
|
48 |
globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[i+1], marker=self.marker[i+1], label='{}'.format(name), linewidth=self.lw, markersize=self.ms) |
|
|
49 |
i = i+1 |
|
|
50 |
ax.add_line(globals()[name]) |
|
|
51 |
method_name.append(globals()[name]) |
|
|
52 |
|
|
|
53 |
plt.legend(handles=method_name) |
|
|
54 |
|
|
|
55 |
plt.rc('xtick', labelsize=self.fontsize) |
|
|
56 |
plt.rc('ytick', labelsize=self.fontsize) |
|
|
57 |
plt.rc('font', **self.font) |
|
|
58 |
|
|
|
59 |
plt.show() |
|
|
60 |
|
|
|
61 |
def save(self, title, filename, |
|
|
62 |
rank=20, xlabel='Rank', |
|
|
63 |
ylabel='Matching Rates (%)', show_grid=True, |
|
|
64 |
save_path=os.getcwd(), format='png', **kwargs): |
|
|
65 |
fig, ax = plt.subplots(figsize=(self.sizeX, self.sizeY)) |
|
|
66 |
fig.suptitle(title) |
|
|
67 |
x = list(range(0, rank+1, 5)) |
|
|
68 |
plt.ylim(0.8, 1.0) |
|
|
69 |
plt.xlim(1, rank) |
|
|
70 |
plt.xlabel(xlabel) |
|
|
71 |
plt.ylabel(ylabel) |
|
|
72 |
plt.xticks(x) |
|
|
73 |
plt.grid(show_grid) |
|
|
74 |
|
|
|
75 |
method_name = [] |
|
|
76 |
i = 0 |
|
|
77 |
for name in self.cmc_dict.keys(): |
|
|
78 |
if rank < len(self.cmc_dict[name]): |
|
|
79 |
temp_cmc = self.cmc_dict[name][:rank] |
|
|
80 |
r = list(range(1, rank+1)) |
|
|
81 |
else: |
|
|
82 |
temp_cmc = self.cmc_dict[name] |
|
|
83 |
r = list(range(1, len(temp_cmc)+1)) |
|
|
84 |
|
|
|
85 |
if name == list(self.cmc_dict.keys())[-1]: |
|
|
86 |
#globals()[name] = mlines.Line2D(r, temp_cmc, color='r', marker='*', label='{:.1f}% {}'.format(self.cmc_dict[name][0]*100, name)) |
|
|
87 |
globals()[name] = mlines.Line2D(r, temp_cmc, color='r', marker='*', label='{}'.format(name), linewidth=self.lw, markersize=self.ms) |
|
|
88 |
else: |
|
|
89 |
#globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[i], marker=self.marker[i], label='{:.1f}% {}'.format(self.cmc_dict[name][0]*100, name)) |
|
|
90 |
globals()[name] = mlines.Line2D(r, temp_cmc, color=self.color[i], marker=self.marker[i], label='{}'.format(name), linewidth=self.lw, markersize=self.ms) |
|
|
91 |
i = i+1 |
|
|
92 |
ax.add_line(globals()[name]) |
|
|
93 |
method_name.append(globals()[name]) |
|
|
94 |
|
|
|
95 |
plt.legend(handles=method_name) |
|
|
96 |
|
|
|
97 |
plt.rc('xtick', labelsize=self.fontsize) |
|
|
98 |
plt.rc('ytick', labelsize=self.fontsize) |
|
|
99 |
plt.rc('font', **self.font) |
|
|
100 |
|
|
|
101 |
fig.savefig(os.path.join(save_path, filename+'.'+format), |
|
|
102 |
format=format, |
|
|
103 |
bbox_inches='tight', |
|
|
104 |
pad_inches = 0, **kwargs) |