|
a |
|
b/graphs.py |
|
|
1 |
import plotly.express as px |
|
|
2 |
import matplotlib.pyplot as plt |
|
|
3 |
import pandas as pd |
|
|
4 |
from skimage.transform import resize as skires |
|
|
5 |
from utils import parse_args |
|
|
6 |
|
|
|
7 |
class Graphs: |
|
|
8 |
def __init__(self): |
|
|
9 |
self.config = parse_args() |
|
|
10 |
self.network = self.config.network |
|
|
11 |
self.name = self.config.name |
|
|
12 |
|
|
|
13 |
def visualize(self, epochs, scores, legends, x_label, y_label, title): |
|
|
14 |
colors = ['red', 'blue', 'green', 'purple', 'orange', 'black'] |
|
|
15 |
for score, legend, color in zip(scores, legends, colors): |
|
|
16 |
plt.plot(epochs, score, color, label=legend) |
|
|
17 |
|
|
|
18 |
plt.legend(loc='upper right') |
|
|
19 |
plt.title(title) |
|
|
20 |
plt.xlabel(x_label) |
|
|
21 |
plt.ylabel(y_label) |
|
|
22 |
plt.legend() |
|
|
23 |
# plt.ylim(0.0, 1.0) |
|
|
24 |
plt.savefig(f"outputs/{self.network}/{self.name}/graphs/graph2.jpeg") |
|
|
25 |
plt.show() |
|
|
26 |
|
|
|
27 |
def read_data(self, type): |
|
|
28 |
df = pd.read_csv(f"outputs/{self.network}/{self.name}/{type}.csv") |
|
|
29 |
fields = df.columns.tolist() |
|
|
30 |
metrics = [] |
|
|
31 |
for column in df.columns: |
|
|
32 |
metrics.append(df[column].tolist()) |
|
|
33 |
|
|
|
34 |
return df, fields, metrics |
|
|
35 |
|
|
|
36 |
def training_plotting(self): |
|
|
37 |
_, fields, metrics = self.read_data(type='epo_log') |
|
|
38 |
|
|
|
39 |
# mapping |
|
|
40 |
options = { |
|
|
41 |
'epoch': 0, |
|
|
42 |
'lr': 1, |
|
|
43 |
|
|
|
44 |
'Train loss': 2, |
|
|
45 |
'Train ce loss': 3, |
|
|
46 |
'Train dice score': 4, |
|
|
47 |
'Train dice loss': 5, |
|
|
48 |
'Train iou score': 6, |
|
|
49 |
'Train iou loss': 7, |
|
|
50 |
'Train hausdorff': 8, |
|
|
51 |
|
|
|
52 |
'Val loss': 9, |
|
|
53 |
'Val ce loss': 10, |
|
|
54 |
'Val dice score': 11, |
|
|
55 |
'Val dice loss': 12, |
|
|
56 |
'Val iou score': 13, |
|
|
57 |
'Val iou loss': 14, |
|
|
58 |
'Val hausdorff': 15, |
|
|
59 |
} |
|
|
60 |
|
|
|
61 |
iters = [i for i in range(1, (len(metrics[0])) + 1)] |
|
|
62 |
|
|
|
63 |
train_hausdorff = metrics[options['Train hausdorff']] |
|
|
64 |
train_hausdorff = [x / 100 for x in train_hausdorff] |
|
|
65 |
|
|
|
66 |
val_hausdorff = metrics[options['Val hausdorff']] |
|
|
67 |
val_hausdorff = [x / 100 for x in val_hausdorff] |
|
|
68 |
|
|
|
69 |
|
|
|
70 |
self.visualize( |
|
|
71 |
iters, |
|
|
72 |
[ |
|
|
73 |
train_hausdorff, val_hausdorff, |
|
|
74 |
metrics[options['Train ce loss']], metrics[options['Val ce loss']], |
|
|
75 |
metrics[options['Train iou loss']], metrics[options['Val iou loss']], |
|
|
76 |
# metrics[options['Train loss']], metrics[options['Val loss']] |
|
|
77 |
], |
|
|
78 |
|
|
|
79 |
[ |
|
|
80 |
# fields[options['Train hausdorff']], fields[options['Val hausdorff']], |
|
|
81 |
fields[options['Train hausdorff']] + ' (/100)', fields[options['Val hausdorff']] + ' (/100)', |
|
|
82 |
fields[options['Train ce loss']], fields[options['Val ce loss']], |
|
|
83 |
fields[options['Train iou loss']], fields[options['Val iou loss']] |
|
|
84 |
], |
|
|
85 |
|
|
|
86 |
'Epochs', 'Scores', 'Training results', |
|
|
87 |
) |
|
|
88 |
|
|
|
89 |
# Only use for testing |
|
|
90 |
def boxplot(self): |
|
|
91 |
df_dice = pd.read_csv(f"outputs/{self.config.name}/infer_dice_class.csv") |
|
|
92 |
df_iou = pd.read_csv(f"outputs/{self.config.name}/infer_iou_class.csv") |
|
|
93 |
|
|
|
94 |
df_dice['type'] = 'dice' |
|
|
95 |
df_iou['type'] = 'iou' |
|
|
96 |
|
|
|
97 |
df_combined = pd.concat([df_dice, df_iou]) |
|
|
98 |
df_combined.reset_index(drop=True, inplace=True) |
|
|
99 |
df_final = pd.melt(df_combined, id_vars=['type'], var_name='class', value_name='score') |
|
|
100 |
df_final.sort_values(['type', 'class'], inplace=True) |
|
|
101 |
df_final.reset_index(drop=True, inplace=True) |
|
|
102 |
|
|
|
103 |
fig = px.box(df_final, x="class", y="score", color="type") |
|
|
104 |
fig.update_traces(quartilemethod="exclusive") |
|
|
105 |
fig.update_layout(width=700, height=700) |
|
|
106 |
fig.show() |
|
|
107 |
|
|
|
108 |
if __name__ == '__main__': |
|
|
109 |
graph = Graphs() |
|
|
110 |
graph.training_plotting() |