Diff of /segmentation/plot_logs.py [000000] .. [18498b]

Switch to unified view

a b/segmentation/plot_logs.py
1
'''
2
Copyright (c) Microsoft Corporation. All rights reserved.
3
Licensed under the MIT License.
4
'''
5
import pandas as pd 
6
import matplotlib.pyplot as plt
7
import numpy as np 
8
import os 
9
import glob
10
import sys
11
config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
12
sys.path.append(config_dir)
13
from config import RESULTS_FOLDER
14
15
# %%
16
def plot_train_logs(train_fpaths, valid_fpaths, network_names):
17
    train_dfs = [pd.read_csv(path) for path in train_fpaths]
18
    valid_dfs = [pd.read_csv(path) for path in valid_fpaths]
19
20
    train_losses = [df['Loss'].values for df in train_dfs]
21
    valid_metrics = [df['Metric'].values for df in valid_dfs]
22
    train_epochs = [np.arange(len(train_loss))+1 for train_loss in train_losses]
23
    valid_epochs = [2*(np.arange(len(valid_metric))+1) for valid_metric in valid_metrics]
24
    min_losses = [np.min(train_loss) for train_loss in train_losses]
25
    min_losses_epoch = [np.argmin(train_loss) + 1 for train_loss in train_losses]
26
    max_dscs = [np.max(valid_metric) for valid_metric in valid_metrics]
27
    max_dscs_epoch = [2*(np.argmax(valid_metric)+1) for valid_metric in valid_metrics]
28
    fig, ax = plt.subplots(1,2, figsize=(20,10))
29
    fig.patch.set_facecolor('white')
30
    fig.patch.set_alpha(1)
31
32
    for i in range(len(train_losses)):
33
        ax[0].plot(train_epochs[i], train_losses[i])
34
        ax[1].plot(valid_epochs[i], valid_metrics[i])
35
        ax[0].plot(min_losses_epoch[i], min_losses[i], '-o', color='red')
36
        ax[1].plot(max_dscs_epoch[i], max_dscs[i], '-o', color='red')
37
        
38
        ax[0].text(np.min(train_epochs[i]), np.min(train_losses[i]), f'Total epochs: {len(train_epochs[i])}', fontsize=15)
39
        
40
    legend_labels_trainloss = [f"{network_names[i]}; Min loss: {round(min_losses[i], 4)} ({len(train_epochs[i])})" for i in range(len(network_names))]
41
    legend_labels_validdice = [f"{network_names[i]}; Max DSC: {round(max_dscs[i], 4)} ({len(valid_epochs[i])})" for i in range(len(network_names))]
42
43
    ax[0].legend(legend_labels_trainloss, fontsize=16)
44
    ax[1].legend(legend_labels_validdice, fontsize=16)
45
    ax[0].set_title('Train loss', fontsize=25)
46
    ax[1].set_title('Valid DSC', fontsize=25)
47
    ax[0].set_ylabel('Dice loss', fontsize=20)
48
    ax[1].set_ylabel('Dice score', fontsize=20)
49
    ax[0].grid(True)
50
    ax[1].grid(True)
51
    plt.show()
52
53
#%%
54
fold = 0
55
network = ['unet']
56
inputsize = [192, 192, 160, 128]
57
p = 2
58
inputsize_dict = {
59
    'unet': 192,
60
    'attentionunet': 192,
61
    'segresnet': 192,
62
    'dynunet': 160,
63
    'unetr': 160,
64
    'swinunetr': 128
65
}
66
67
experiment_code = [f"{network[i]}_fold{fold}_randcrop{inputsize[i]}" for i in range(len(network))]
68
save_logs_dir = os.path.join(RESULTS_FOLDER, 'logs')
69
save_logs_folders = [os.path.join(save_logs_dir, 'fold'+str(fold), network[i], experiment_code[i]) for i in range(len(experiment_code))]
70
train_fpaths = [os.path.join(save_logs_folders[i], 'trainlog_gpu0.csv') for i in range(len(save_logs_folders))]
71
valid_fpaths = [os.path.join(save_logs_folders[i], 'validlog_gpu0.csv') for i in range(len(save_logs_folders))]
72
legend_lbls = [f'{network[i]}, N = {inputsize[i]}' for i in range(len(network))]
73
plot_train_logs(train_fpaths, valid_fpaths, legend_lbls)
74