Diff of /graphs.py [000000] .. [70e190]

Switch to unified view

a b/graphs.py
1
import plotly.express as px
2
import matplotlib.pyplot as plt
3
import pandas as pd
4
from skimage.transform import resize as skires
5
from utils import parse_args
6
7
class Graphs:
8
    def __init__(self):
9
        self.config = parse_args()
10
        self.network = self.config.network
11
        self.name = self.config.name
12
    
13
    def visualize(self, epochs, scores, legends, x_label, y_label, title):
14
        colors = ['red', 'blue', 'green', 'purple', 'orange', 'black'] 
15
        for score, legend, color in zip(scores, legends, colors):
16
            plt.plot(epochs, score, color, label=legend)
17
        
18
        plt.legend(loc='upper right')    
19
        plt.title(title)
20
        plt.xlabel(x_label)
21
        plt.ylabel(y_label)
22
        plt.legend()
23
        # plt.ylim(0.0, 1.0)
24
        plt.savefig(f"outputs/{self.network}/{self.name}/graphs/graph2.jpeg")
25
        plt.show()
26
27
    def read_data(self, type):
28
        df = pd.read_csv(f"outputs/{self.network}/{self.name}/{type}.csv")
29
        fields = df.columns.tolist()  
30
        metrics = []
31
        for column in df.columns:
32
            metrics.append(df[column].tolist())
33
            
34
        return df, fields, metrics
35
36
    def training_plotting(self):
37
        _, fields, metrics = self.read_data(type='epo_log')
38
            
39
        # mapping
40
        options = {
41
            'epoch': 0,
42
            'lr': 1,
43
            
44
            'Train loss': 2,
45
            'Train ce loss': 3,
46
            'Train dice score': 4,
47
            'Train dice loss': 5,
48
            'Train iou score': 6,
49
            'Train iou loss': 7,
50
            'Train hausdorff': 8,
51
            
52
            'Val loss': 9,
53
            'Val ce loss': 10,
54
            'Val dice score': 11,
55
            'Val dice loss': 12,
56
            'Val iou score': 13,
57
            'Val iou loss': 14,
58
            'Val hausdorff': 15,
59
        }
60
            
61
        iters = [i for i in range(1, (len(metrics[0])) + 1)]
62
        
63
        train_hausdorff = metrics[options['Train hausdorff']]
64
        train_hausdorff = [x / 100 for x in train_hausdorff]
65
        
66
        val_hausdorff = metrics[options['Val hausdorff']]
67
        val_hausdorff = [x / 100 for x in val_hausdorff]
68
        
69
        
70
        self.visualize(
71
            iters, 
72
            [
73
            train_hausdorff, val_hausdorff,
74
            metrics[options['Train ce loss']], metrics[options['Val ce loss']],
75
            metrics[options['Train iou loss']], metrics[options['Val iou loss']],
76
            # metrics[options['Train loss']], metrics[options['Val loss']] 
77
            ],  
78
            
79
            [
80
            # fields[options['Train hausdorff']], fields[options['Val hausdorff']],
81
            fields[options['Train hausdorff']] + ' (/100)', fields[options['Val hausdorff']] + ' (/100)', 
82
            fields[options['Train ce loss']], fields[options['Val ce loss']],
83
            fields[options['Train iou loss']], fields[options['Val iou loss']] 
84
            ],      
85
            
86
            'Epochs', 'Scores', 'Training results', 
87
        )
88
    
89
    # Only use for testing 
90
    def boxplot(self):
91
        df_dice = pd.read_csv(f"outputs/{self.config.name}/infer_dice_class.csv")
92
        df_iou  = pd.read_csv(f"outputs/{self.config.name}/infer_iou_class.csv")
93
        
94
        df_dice['type'] = 'dice'
95
        df_iou['type'] = 'iou'
96
97
        df_combined = pd.concat([df_dice, df_iou])
98
        df_combined.reset_index(drop=True, inplace=True)
99
        df_final = pd.melt(df_combined, id_vars=['type'], var_name='class', value_name='score')
100
        df_final.sort_values(['type', 'class'], inplace=True)
101
        df_final.reset_index(drop=True, inplace=True)
102
103
        fig = px.box(df_final, x="class", y="score", color="type")
104
        fig.update_traces(quartilemethod="exclusive")
105
        fig.update_layout(width=700, height=700)
106
        fig.show()
107
108
if __name__ == '__main__':
109
    graph = Graphs()
110
    graph.training_plotting()