Diff of /code/train-generator.py [000000] .. [ebf7be]

Switch to unified view

a b/code/train-generator.py
1
"""
2
Purpose: train a machine learning segmenter that can segment out the nodules on a given 2D patient CT scan slice
3
Note:
4
- this will train from scratch, with no preloaded weights
5
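- NOTE: this script does not currently write the trained weights to disk (ModelCheckpoint is imported but never used); add a checkpoint callback or call model.save_weights() to keep them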
6
"""
7
8
from __future__ import print_function
9
import os
10
from glob import glob
11
import numpy as np
12
import matplotlib.pyplot as plt
13
from keras.models import Model
14
from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D
15
from keras.optimizers import Adam
16
from keras.optimizers import SGD
17
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
18
from keras import backend as K
19
20
21
# Directories searched for the train / validation / test .npy arrays.
# Each ends with a slash because file names are appended by plain
# string concatenation.
TRAIN_PATH = '/home/ubuntu/data/train_pre/'
VAL_PATH = '/home/ubuntu/data/val_pre/'
TEST_PATH = '/home/ubuntu/data/test_pre/'

# Height and width, in pixels, of the network input.
IMG_ROWS, IMG_COLS = 512, 512

# Constant added to both the numerator and the denominator of the Dice
# score so that the ratio stays defined when both masks sum to zero.
SMOOTH = 1.0

# Use the Theano ("th") dimension ordering throughout this file.
K.set_image_dim_ordering('th')
30
31
def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient of two backend tensors.

    Both tensors are flattened, so the score is computed over every
    element at once rather than per sample.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.

    Returns:
        Backend scalar ``(2 * sum(t * p) + SMOOTH) /
        (sum(t) + sum(p) + SMOOTH)``.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    numerator = 2. * overlap + SMOOTH
    denominator = K.sum(truth) + K.sum(prediction) + SMOOTH
    return numerator / denominator
36
37
def dice_coef_loss(y_true, y_pred):
    """Return the negated Dice coefficient, for use as a training loss.

    Minimising this value maximises ``dice_coef(y_true, y_pred)``.
    """
    score = dice_coef(y_true, y_pred)
    return -score
39
40
def dice_coef_np(y_true, y_pred):
    """Return the smoothed Dice coefficient of two NumPy arrays.

    Uses the same formula as ``dice_coef`` but runs on NumPy arrays
    instead of backend tensors.

    Args:
        y_true: Ground-truth array (must provide ``.flatten()``).
        y_pred: Predicted array (must provide ``.flatten()``).

    Returns:
        ``(2 * sum(t * p) + SMOOTH) / (sum(t) + sum(p) + SMOOTH)``
        computed over all elements.
    """
    truth = y_true.flatten()
    prediction = y_pred.flatten()
    overlap = np.sum(truth * prediction)
    numerator = 2. * overlap + SMOOTH
    denominator = np.sum(truth) + np.sum(prediction) + SMOOTH
    return numerator / denominator
45
46
def get_unet():
    """Build and compile the U-Net style segmentation model.

    The network takes a single-channel ``(1, IMG_ROWS, IMG_COLS)`` input
    and runs:

    * four conv-conv-pool stages with 32, 64, 128 and 256 filters,
    * a conv-conv stage with 512 filters,
    * four upsample-concatenate-conv-conv stages with 256, 128, 64 and 32
      filters, each concatenated with the matching earlier stage output,
    * a final 1x1 sigmoid convolution with one output channel.

    Returns:
        The model, compiled with Adam (lr=1.0e-5), ``dice_coef_loss`` as
        the loss and ``dice_coef`` as a metric.
    """
    def conv_pair(tensor, n_filters):
        # Two stacked 3x3 ReLU convolutions with 'same' border mode.
        for _ in range(2):
            tensor = Convolution2D(n_filters, 3, 3, activation='relu', border_mode='same')(tensor)
        return tensor

    def upsample_and_join(tensor, skip):
        # Upsample by 2 in both spatial directions, then concatenate with
        # the stored stage output along axis 1 (axis 0 is the batch).
        upsampled = UpSampling2D(size=(2, 2))(tensor)
        return merge([upsampled, skip], mode='concat', concat_axis=1)

    inputs = Input((1, IMG_ROWS, IMG_COLS))

    # Contracting path: keep each stage's pre-pooling output for later.
    skips = []
    features = inputs
    for n_filters in (32, 64, 128, 256):
        features = conv_pair(features, n_filters)
        skips.append(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    # Deepest stage; no pooling follows it.
    features = conv_pair(features, 512)

    # Expanding path: the most recently stored output is used first.
    for n_filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        features = conv_pair(upsample_and_join(features, skip), n_filters)

    outputs = Convolution2D(1, 1, 1, activation='sigmoid')(features)

    model = Model(input=inputs, output=outputs)
    model.compile(optimizer=Adam(lr=1.0e-5), loss=dice_coef_loss, metrics=[dice_coef])
    return model
90
91
def generator(path, batch_size):
    """Yield ``(lung, nodule)`` batches forever from paired ``.npy`` files.

    Files in ``path`` matching ``final_lung_mask_*.npy`` and
    ``final_nodule_mask_*.npy`` are sorted by name and paired by position.
    Each file is loaded whole and cut into consecutive batches along its
    first axis. Slices left over at the end of a file are carried over and
    completed with the first slices of the following file(s); after the
    last file the loop wraps around to the first one, so the generator
    never terminates on its own.

    Args:
        path: Directory prefix, including the trailing separator (file
            patterns are appended by string concatenation).
        batch_size: Number of slices per yielded batch; must be >= 1.

    Yields:
        Tuples ``(lung_batch, nodule_batch)``, each of shape
        ``(batch_size, 1, H, W)`` where ``H, W`` are the last two
        dimensions of the loaded arrays, with values divided by 255.0.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1, no lung files are
            found, the lung and nodule file counts differ, a lung/nodule
            pair holds different numbers of slices, or a full pass over
            the files yields no slices at all. As this is a generator,
            the error surfaces on the first ``next()`` call.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

    lung_files = sorted(glob(path + 'final_lung_mask_*.npy'))
    nodule_files = sorted(glob(path + 'final_nodule_mask_*.npy'))
    if not lung_files:
        # Without this check the endless loop below would spin forever
        # without ever yielding.
        raise ValueError('no final_lung_mask_*.npy files found in {!r}'.format(path))
    if len(lung_files) != len(nodule_files):
        raise ValueError(
            'found {} lung files but {} nodule files in {!r}'.format(
                len(lung_files), len(nodule_files), path))

    def as_batch(slices):
        # Insert the single channel axis and scale the values by 1/255.
        return slices.reshape((batch_size, 1) + slices.shape[-2:]) / 255.0

    # Slices left over from earlier file(s), always fewer than batch_size.
    carry_lung = carry_nodule = None

    while True:
        slices_seen = 0
        for lung_file, nodule_file in zip(lung_files, nodule_files):
            lung = np.load(lung_file)
            nodule = np.load(nodule_file)
            if len(lung) != len(nodule):
                raise ValueError(
                    '{!r} has {} slices but {!r} has {}'.format(
                        lung_file, len(lung), nodule_file, len(nodule)))
            slices_seen += len(lung)

            start = 0
            if carry_lung is not None:
                # Top up the pending partial batch from the head of this
                # file. The file may be too short to complete it.
                start = min(batch_size - len(carry_lung), len(lung))
                carry_lung = np.concatenate((carry_lung, lung[:start]), axis=0)
                carry_nodule = np.concatenate((carry_nodule, nodule[:start]), axis=0)
                if len(carry_lung) < batch_size:
                    # Whole file consumed and still short: keep collecting.
                    continue
                yield as_batch(carry_lung), as_batch(carry_nodule)
                carry_lung = carry_nodule = None

            # Emit every complete batch that lies fully inside this file.
            while start + batch_size <= len(lung):
                stop = start + batch_size
                yield as_batch(lung[start:stop]), as_batch(nodule[start:stop])
                start = stop

            # Remember the tail so the next file can complete it.
            if start < len(lung):
                carry_lung, carry_nodule = lung[start:], nodule[start:]

        if slices_seen == 0:
            # Every file was empty; looping again could never yield.
            raise ValueError('files in {!r} contain no slices'.format(path))
133
134
135
def train_generator(batch_size, steps_per_epoch=4540, epochs=2, validation_steps=650):
    """Train the U-Net from generators, then evaluate it on one test file.

    The model from ``get_unet()`` is fitted on batches drawn from
    ``TRAIN_PATH`` and validated on batches drawn from ``VAL_PATH``. It is
    then evaluated on ``final_lung_mask_9.npy`` /
    ``final_nodule_mask_9.npy`` under ``TEST_PATH`` and the evaluation
    result is printed.

    Args:
        batch_size: Slices per batch for training, validation and
            evaluation.
        steps_per_epoch: Training batches drawn per epoch.
        epochs: Number of training epochs.
        validation_steps: Validation batches drawn after each epoch.

    Returns:
        The trained model, so the caller can save or reuse it.
    """
    model = get_unet()
    print('model compileover ...')
    model.fit_generator(
        generator(TRAIN_PATH, batch_size),
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=1,
        validation_data=generator(VAL_PATH, batch_size),
        validation_steps=validation_steps
    )

    # -1 lets NumPy infer the slice count instead of assuming 1007.
    test_shape = [-1, 1, IMG_ROWS, IMG_COLS]
    data_pred = np.load(TEST_PATH + 'final_lung_mask_9.npy').reshape(test_shape)
    nodule_true = np.load(TEST_PATH + 'final_nodule_mask_9.npy').reshape(test_shape)

    # Scale exactly as generator() does for training data; otherwise the
    # Dice metric is computed on values the model was never trained on.
    # float32 keeps the scaled copies at half the size of float64.
    data_pred = data_pred.astype(np.float32)
    data_pred /= 255.0
    nodule_true = nodule_true.astype(np.float32)
    nodule_true /= 255.0

    print(model.evaluate(data_pred, nodule_true, batch_size=batch_size, verbose=1))
    return model
144
145
146
    
147
if __name__ == '__main__':
    # Script entry point: train with two slices per batch.
    batch_size = 2
    train_generator(batch_size=batch_size)