Diff of /train_ensemble.py [000000] .. [94d9b6]

Switch to unified view

a b/train_ensemble.py
1
import time
2
import torch
3
import torch.nn as nn
4
from torch.autograd import Variable
5
import torch.optim as optim
6
import torchvision
7
from torchvision import datasets, models
8
from torchvision import transforms as T
9
from torch.utils.data import DataLoader, Dataset
10
import numpy as np
11
import matplotlib.pyplot as plt
12
import os
13
import time
14
import pandas as pd
15
from skimage import io, transform
16
import matplotlib.image as mpimg
17
from PIL import Image
18
from sklearn.metrics import roc_auc_score
19
import torch.nn.functional as F
20
import scipy
21
import random
22
import pickle
23
import scipy.io as sio
24
import itertools
25
from scipy.ndimage.interpolation import shift
26
import copy
27
import warnings
28
warnings.filterwarnings("ignore")
29
plt.ion()
30
31
from utils import *
32
33
def train_pp_model(model,prediction_models, optimizer,dataloader, data_sizes, batch_size, name, num_epochs = 100,
34
                verbose = False, dice_loss = dice_loss):
35
    since = time.time()
36
    best_loss = np.inf
37
    best_dice_cl0 = 0
38
    best_dice_cl1 = 0
39
    best_dice_cl2 = 0
40
    best_score = 0
41
    best_model_wts = copy.deepcopy(model.state_dict())
42
    loss_hist = {'train':[],'validate':[]}
43
    dice_score_0_hist = {'train':[],'validate':[]}
44
    dice_score_1_hist = {'train':[],'validate':[]}
45
    dice_score_2_hist = {'train':[],'validate':[]}
46
    for i in prediction_models:
47
        for param in i.parameters():
48
            param.requires_grad = False
49
    for i in range(num_epochs):
50
        for phase in ['train', 'validate']:
51
            running_loss = 0
52
            running_dice_score_class_0 = 0
53
            running_dice_score_class_1 = 0
54
            running_dice_score_class_2 = 0
55
            
56
            if phase == 'train':
57
                model.train(True)
58
            else:
59
                model.train(False)
60
    
61
            for data in dataloader[phase]:
62
                optimizer.zero_grad()
63
                input, segF, segP, segT,_ = data
64
                input = Variable(input).cuda()
65
                true = Variable(segments(segF, segP, segT)).cuda()
66
                input_pp = []
67
                for j in prediction_models:
68
                    output = j(input)
69
                    preds = predict_pp(output)
70
                    input_pp.append(preds)
71
                input_pp = torch.cat(input_pp,dim = 1)
72
                output_pp = model(input_pp)
73
                loss = dice_loss(true,output_pp)
74
                if phase == 'train':
75
                    loss.backward()
76
                    optimizer.step()
77
                running_loss += loss.data[0] * batch_size
78
                dice_score_batch = dice_score(true,output_pp)
79
                running_dice_score_class_0 += dice_score_batch[0] * batch_size
80
                running_dice_score_class_1 += dice_score_batch[1] * batch_size
81
                running_dice_score_class_2 += dice_score_batch[2] * batch_size
82
            epoch_loss = running_loss/data_sizes[phase]
83
            loss_hist[phase].append(epoch_loss) 
84
            epoch_dice_score_0 = running_dice_score_class_0/data_sizes[phase]
85
            epoch_dice_score_1 = running_dice_score_class_1/data_sizes[phase]
86
            epoch_dice_score_2 = running_dice_score_class_2/data_sizes[phase]
87
            dice_score_0_hist[phase].append(epoch_dice_score_0)
88
            dice_score_1_hist[phase].append(epoch_dice_score_1)
89
            dice_score_2_hist[phase].append(epoch_dice_score_2)
90
            epoch_score = epoch_dice_score_0 + epoch_dice_score_1 + epoch_dice_score_2
91
            if verbose or i%10 == 0:
92
                print('Epoch: {}, Phase: {}, epoch loss: {:.4f}, Dice Score (class 0): {:.4f}, Dice Score (class 1): {:.4f},Dice Score (class 2): {:.4f}'.format(i,phase,epoch_loss, epoch_dice_score_0, epoch_dice_score_1, epoch_dice_score_2))
93
                print('-'*10)
94
            
95
#         if phase == 'validate' and epoch_loss < best_loss:
96
#             best_loss = epoch_loss
97
#             best_model_wts = copy.deepcopy(model.state_dict())
98
#             torch.save(model,name)
99
#             best_dice_cl0 = epoch_dice_score_0
100
#             best_dice_cl1 = epoch_dice_score_1
101
#             best_dice_cl2 = epoch_dice_score_2
102
        if phase == 'validate' and epoch_score > best_score:
103
            best_score = epoch_score
104
            best_model_wts = copy.deepcopy(model.state_dict())
105
            torch.save(model,name)
106
            best_dice_cl0 = epoch_dice_score_0
107
            best_dice_cl1 = epoch_dice_score_1
108
            best_dice_cl2 = epoch_dice_score_2
109
            best_loss = epoch_loss
110
    print('-'*50)    
111
    time_elapsed = time.time() - since
112
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
113
    print('Best val dice loss: {:4f}, dice score (class 0): {:.4f}, dice score (class 1): {:.4f},dice score (class 2): {:.4f}'\
114
          .format(best_loss, best_dice_cl0, best_dice_cl1, best_dice_cl2))
115
    
116
    model.load_state_dict(best_model_wts)
117
    
118
    return model, loss_hist, dice_score_0_hist, dice_score_1_hist, dice_score_2_hist