Diff of /train_3d.py [000000] .. [94d9b6]

Switch to unified view

a b/train_3d.py
1
import time
2
import torch
3
import torch.nn as nn
4
from torch.autograd import Variable
5
import torch.optim as optim
6
import torchvision
7
from torchvision import datasets, models
8
from torchvision import transforms as T
9
from torch.utils.data import DataLoader, Dataset
10
import numpy as np
11
import matplotlib.pyplot as plt
12
import os
13
import time
14
import pandas as pd
15
from skimage import io, transform
16
import matplotlib.image as mpimg
17
from PIL import Image
18
from sklearn.metrics import roc_auc_score
19
import torch.nn.functional as F
20
import scipy
21
import random
22
import pickle
23
import scipy.io as sio
24
import itertools
25
from scipy.ndimage.interpolation import shift
26
27
import warnings
28
warnings.filterwarnings("ignore")
29
plt.ion()
30
31
from utils import *
32
33
def train_model_3d(model, optimizer,dataloader, data_sizes, batch_size, num_epochs = 100, verbose = False):
34
    since = time.time()
35
    best_loss = np.inf
36
    loss_hist = {'train':[],'validate':[]}
37
    dice_score_0_hist = {'train':[],'validate':[]}
38
    dice_score_1_hist = {'train':[],'validate':[]}
39
    dice_score_2_hist = {'train':[],'validate':[]}
40
    for i in range(num_epochs):
41
        for phase in ['train', 'validate']:
42
            running_loss = 0
43
            running_dice_score_class_0 = 0
44
            running_dice_score_class_1 = 0
45
            running_dice_score_class_2 = 0
46
47
            if phase == 'train':
48
                model.train(True)
49
            else:
50
                model.train(False)
51
52
            for data in dataloader[phase]:
53
                optimizer.zero_grad()
54
                input, target,_ = data
55
                input = Variable(input).cuda()
56
                true = Variable(torch.transpose(target,2,1).contiguous().view(-1,4,256,256)).cuda()
57
                output = model(input)
58
                output_reshaped = torch.transpose(output,2,1).contiguous().view(-1,4,256,256)
59
                loss = dice_loss(true,output_reshaped,p=2.5)
60
                if phase == 'train':
61
                    loss.backward()
62
                    optimizer.step()
63
                running_loss += loss.data[0] * batch_size
64
                dice_score_batch = dice_score(true,output_reshaped)
65
                running_dice_score_class_0 += dice_score_batch[0] * batch_size
66
                running_dice_score_class_1 += dice_score_batch[1] * batch_size
67
                running_dice_score_class_2 += dice_score_batch[2] * batch_size
68
            epoch_loss = running_loss/data_sizes[phase]
69
            loss_hist[phase].append(epoch_loss) 
70
            epoch_dice_score_0 = running_dice_score_class_0/data_sizes[phase]
71
            epoch_dice_score_1 = running_dice_score_class_1/data_sizes[phase]
72
            epoch_dice_score_2 = running_dice_score_class_2/data_sizes[phase]
73
            dice_score_0_hist[phase].append(epoch_dice_score_0)
74
            dice_score_1_hist[phase].append(epoch_dice_score_1)
75
            dice_score_2_hist[phase].append(epoch_dice_score_2)
76
            if verbose or i%10 == 0:
77
                print('Epoch: {}, Phase: {}, epoch loss: {:.4f}, Dice Score (class 0): {:.4f}, Dice Score (class 1): {:.4f},Dice Score (class 2): {:.4f}'.format(i,phase,epoch_loss, epoch_dice_score_0, epoch_dice_score_1, epoch_dice_score_2))
78
                print('-'*10)
79
80
        if phase == 'validate' and epoch_loss < best_loss:
81
            best_loss = epoch_loss
82
            best_model_wts = model.state_dict() 
83
84
    print('-'*50)    
85
    time_elapsed = time.time() - since
86
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
87
    print('Best val dice loss: {:4f}'.format(best_loss))
88
89
    model.load_state_dict(best_model_wts)
90
91
    return model, loss_hist, dice_score_0_hist, dice_score_1_hist, dice_score_2_hist