Diff of /train_2d.py [000000] .. [94d9b6]

Switch to unified view

a b/train_2d.py
1
import time
2
import torch
3
import torch.nn as nn
4
from torch.autograd import Variable
5
import torch.optim as optim
6
import torchvision
7
from torchvision import datasets, models
8
from torchvision import transforms as T
9
from torch.utils.data import DataLoader, Dataset
10
import numpy as np
11
import matplotlib.pyplot as plt
12
import os
13
import time
14
import pandas as pd
15
from skimage import io, transform
16
import matplotlib.image as mpimg
17
from PIL import Image
18
from sklearn.metrics import roc_auc_score
19
import torch.nn.functional as F
20
import scipy
21
import random
22
import pickle
23
import scipy.io as sio
24
import itertools
25
from scipy.ndimage.interpolation import shift
26
import copy
27
import warnings
28
warnings.filterwarnings("ignore")
29
plt.ion()
30
31
from utils import *
32
33
def train_model(model, optimizer,dataloader, data_sizes, batch_size, name, num_epochs = 100,
34
                verbose = False, dice_loss = dice_loss, noisy_labels = False):
35
    since = time.time()
36
    best_loss = np.inf
37
    best_dice_cl0 = 0
38
    best_dice_cl1 = 0
39
    best_dice_cl2 = 0
40
    best_model_wts = copy.deepcopy(model.state_dict())
41
    loss_hist = {'train':[],'validate':[]}
42
    dice_score_0_hist = {'train':[],'validate':[]}
43
    dice_score_1_hist = {'train':[],'validate':[]}
44
    dice_score_2_hist = {'train':[],'validate':[]}
45
    for i in range(num_epochs):
46
        for phase in ['train', 'validate']:
47
            running_loss = 0
48
            running_dice_score_class_0 = 0
49
            running_dice_score_class_1 = 0
50
            running_dice_score_class_2 = 0
51
            
52
            if phase == 'train':
53
                model.train(True)
54
            else:
55
                model.train(False)
56
    
57
            for data in dataloader[phase]:
58
                optimizer.zero_grad()
59
                input, segF, segP, segT,_ = data
60
                input = Variable(input).cuda()
61
                true = Variable(segments(segF, segP, segT)).cuda()
62
                output = model(input)
63
                if noisy_labels and phase == 'train':
64
                    noise = generate_noise(true)
65
                    true = true + true*noise
66
                if noisy_labels:
67
                    loss = entropy_loss(true,output)
68
                else:
69
                    loss = dice_loss(true,output)
70
                if phase == 'train':
71
                    loss.backward()
72
                    optimizer.step()
73
                running_loss += loss.data[0] * batch_size
74
                dice_score_batch = dice_score(true,output)
75
                running_dice_score_class_0 += dice_score_batch[0] * batch_size
76
                running_dice_score_class_1 += dice_score_batch[1] * batch_size
77
                running_dice_score_class_2 += dice_score_batch[2] * batch_size
78
            epoch_loss = running_loss/data_sizes[phase]
79
            loss_hist[phase].append(epoch_loss) 
80
            epoch_dice_score_0 = running_dice_score_class_0/data_sizes[phase]
81
            epoch_dice_score_1 = running_dice_score_class_1/data_sizes[phase]
82
            epoch_dice_score_2 = running_dice_score_class_2/data_sizes[phase]
83
            dice_score_0_hist[phase].append(epoch_dice_score_0)
84
            dice_score_1_hist[phase].append(epoch_dice_score_1)
85
            dice_score_2_hist[phase].append(epoch_dice_score_2)
86
            if verbose or i%10 == 0:
87
                print('Epoch: {}, Phase: {}, epoch loss: {:.4f}, Dice Score (class 0): {:.4f}, Dice Score (class 1): {:.4f},Dice Score (class 2): {:.4f}'.format(i,phase,epoch_loss, epoch_dice_score_0, epoch_dice_score_1, epoch_dice_score_2))
88
                print('-'*10)
89
            
90
        if phase == 'validate' and epoch_loss < best_loss:
91
            best_loss = epoch_loss
92
            best_model_wts = copy.deepcopy(model.state_dict())
93
            torch.save(model,name)
94
            best_dice_cl0 = epoch_dice_score_0
95
            best_dice_cl1 = epoch_dice_score_1
96
            best_dice_cl2 = epoch_dice_score_2
97
    print('-'*50)    
98
    time_elapsed = time.time() - since
99
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
100
    print('Best val dice loss: {:4f}, dice score (class 0): {:.4f}, dice score (class 1): {:.4f},dice score (class 2): {:.4f}'\
101
          .format(best_loss, best_dice_cl0, best_dice_cl1, best_dice_cl2))
102
    
103
    model.load_state_dict(best_model_wts)
104
    
105
    return model, loss_hist, dice_score_0_hist, dice_score_1_hist, dice_score_2_hist
106
107
import time
108
def train_model_patches(model, optimizer,dataloader, data_sizes, batch_size, num_epochs = 100, verbose = False):
109
    since = time.time()
110
    best_loss = np.inf
111
    loss_hist = {'train':[],'validate':[]}
112
    dice_score_0_hist = {'train':[],'validate':[]}
113
    dice_score_1_hist = {'train':[],'validate':[]}
114
    dice_score_2_hist = {'train':[],'validate':[]}
115
    for i in range(num_epochs):
116
        for phase in ['train', 'validate']:
117
            running_loss = 0
118
            running_dice_score_class_0 = 0
119
            running_dice_score_class_1 = 0
120
            running_dice_score_class_2 = 0
121
            
122
            if phase == 'train':
123
                model.train(True)
124
            else:
125
                model.train(False)
126
    
127
            for data in dataloader[phase]:
128
                optimizer.zero_grad()
129
                input, seg,_ = data
130
                input = Variable(input[:,:,1:49,1:49]).cuda()
131
                true = Variable(seg[:,:,1:49,1:49]).cuda()
132
                output = model(input)
133
                loss = dice_loss(true,output)
134
                if phase == 'train':
135
                    loss.backward()
136
                    optimizer.step()
137
                running_loss += loss.data[0] * batch_size
138
                dice_score_batch = dice_score(true,output)
139
                running_dice_score_class_0 += dice_score_batch[0] * batch_size
140
                running_dice_score_class_1 += dice_score_batch[1] * batch_size
141
                running_dice_score_class_2 += dice_score_batch[2] * batch_size
142
            epoch_loss = running_loss/data_sizes[phase]
143
            loss_hist[phase].append(epoch_loss) 
144
            epoch_dice_score_0 = running_dice_score_class_0/data_sizes[phase]
145
            epoch_dice_score_1 = running_dice_score_class_1/data_sizes[phase]
146
            epoch_dice_score_2 = running_dice_score_class_2/data_sizes[phase]
147
            dice_score_0_hist[phase].append(epoch_dice_score_0)
148
            dice_score_1_hist[phase].append(epoch_dice_score_1)
149
            dice_score_2_hist[phase].append(epoch_dice_score_2)
150
            if verbose or i%10 == 0:
151
                print('Epoch: {}, Phase: {}, epoch loss: {:.4f}, Dice Score (class 1): {:.4f}, Dice Score (class 2): {:.4f},Dice Score (class 3): {:.4f}'.format(i,phase,epoch_loss, epoch_dice_score_0, epoch_dice_score_1, epoch_dice_score_2))
152
                print('-'*10)
153
            
154
        if phase == 'validate' and epoch_loss < best_loss:
155
            best_loss = epoch_loss
156
            best_model_wts = model.state_dict() 
157
        
158
        if phase == 'validate':
159
            torch.save(model,'unet_patches_4_epoch{}'.format(i))
160
            
161
    print('-'*50)    
162
    time_elapsed = time.time() - since
163
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
164
    print('Best val dice loss: {:4f}'.format(best_loss))
165
    
166
    model.load_state_dict(best_model_wts)
167
    
168
    return model, loss_hist, dice_score_0_hist, dice_score_1_hist, dice_score_2_hist