Diff of /train_cnn.py [000000] .. [75e50a]

Switch to unified view

a b/train_cnn.py
1
import keras
2
from keras.preprocessing.image import ImageDataGenerator
3
from keras.models import Sequential
4
from keras.layers import Dense, Dropout, Activation, Flatten, Reshape
5
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
6
from keras.preprocessing.image import ImageDataGenerator
7
from keras import regularizers
8
from keras.losses import mean_squared_error
9
import glob
10
import matplotlib.patches as patches
11
import json
12
import numpy as np
13
from matplotlib.path import Path
14
import dicom
15
import cv2
16
17
from utils import *
18
19
def create_model(activation, input_shape=(64, 64)):
20
    """
21
    Simple convnet model : one convolution, one average pooling and one fully connected layer
22
    :param activation: None if nothing passed, e.g : ReLu, tanh, etc.
23
    :return: Keras model
24
    """
25
    model = Sequential()
26
    model.add(Conv2D(100, (11,11), activation=activation, padding='valid', strides=(1, 1), input_shape=(input_shape[0], input_shape[1], 1)))
27
    model.add(AveragePooling2D((6,6)))
28
    model.add(Reshape([-1, 8100]))
29
    model.add(Dense(1024, activation='sigmoid', kernel_regularizer=regularizers.l2(0.0001)))
30
    model.add(Reshape([-1, 32, 32]))
31
    return model
32
33
def create_model_maxpooling(activation, input_shape=(64, 64)):
34
    """
35
    Simple convnet model with max pooling: one convolution, one max pooling and one fully connected layer
36
    :param activation: None if nothing passed, e.g : ReLu, tanh, etc.
37
    :return: Keras model
38
    """
39
    model = Sequential()
40
    model.add(Conv2D(100, (11,11), activation=activation, padding='valid', strides=(1, 1), input_shape=(input_shape[0], input_shape[1], 1)))
41
    model.add(MaxPooling2D((6,6)))
42
    model.add(Reshape([-1, 8100]))
43
    model.add(Dense(1024, activation='sigmoid', kernel_regularizer=regularizers.l2(0.0001)))
44
    model.add(Reshape([-1, 32, 32]))
45
    return model
46
47
def create_model_larger(activation, input_shape=(64, 64)):
48
    """
49
    Larger (more filters) convnet model : one convolution, one average pooling and one fully connected layer:
50
    :param activation: None if nothing passed, e.g : ReLu, tanh, etc. 
51
    :return: Keras model
52
    """
53
    model = Sequential()
54
    model.add(Conv2D(200, (11,11), activation=activation, padding='valid', strides=(1, 1), input_shape=(input_shape[0], input_shape[1], 1)))
55
    model.add(AveragePooling2D((6,6)))
56
    model.add(Reshape([-1, 16200]))
57
    model.add(Dense(1024, activation='sigmoid', kernel_regularizer=regularizers.l2(0.0001)))
58
    model.add(Reshape([-1, 32, 32]))
59
    return model
60
61
def create_model_deeper(activation, input_shape=(64, 64)):
62
    """
63
    Deeper convnet model : two convolutions, two average pooling and one fully connected layer:
64
    :param activation: None if nothing passed, e.g : ReLu, tanh, etc.
65
    :return: Keras model
66
    """
67
    model = Sequential()
68
    model.add(Conv2D(64, (11,11), activation=activation, padding='valid', strides=(1, 1), input_shape=(input_shape[0], input_shape[1], 1)))
69
    model.add(AveragePooling2D((2,2)))
70
    model.add(Conv2D(128, (10, 10), activation=activation, padding='valid', strides=(1, 1)))
71
    model.add(AveragePooling2D((2,2)))
72
    model.add(Reshape([-1, 128*9*9]))
73
    model.add(Dense(1024, activation='sigmoid', kernel_regularizer=regularizers.l2(0.0001)))
74
    model.add(Reshape([-1, 32, 32]))
75
    return model
76
77
def create_model_full(activation, input_shape=(64, 64)):
78
    model = Sequential()
79
    model.add(Conv2D(64, (11,11), activation=activation, padding='valid', strides=(1, 1), input_shape=(input_shape[0], input_shape[1], 1)))
80
    model.add(MaxPooling2D((2,2)))
81
    model.add(Conv2D(128, (10, 10), activation=activation, padding='valid', strides=(1, 1)))
82
    model.add(MaxPooling2D((2,2)))
83
    model.add(Reshape([-1, 128*9*9]))
84
    model.add(Dense(1024, activation='sigmoid', kernel_regularizer=regularizers.l2(0.0001)))
85
    model.add(Reshape([-1, 32, 32]))
86
    return model
87
88
def training(m, X, Y, verbose, batch_size=16, epochs=20, data_augm=False):
89
    """
90
    Training CNN with the possibility to use data augmentation
91
    :param m: Keras model
92
    :param epochs: number of epochs
93
    :param X: training pictures
94
    :param Y: training binary ROI mask
95
    :return: history
96
    """
97
    if data_augm:
98
        datagen = ImageDataGenerator(
99
            featurewise_center=False,  # set input mean to 0 over the dataset
100
            samplewise_center=False,  # set each sample mean to 0
101
            featurewise_std_normalization=False,  # divide inputs by std of the dataset
102
            samplewise_std_normalization=False,  # divide each input by its std
103
            zca_whitening=False,  # apply ZCA whitening
104
            rotation_range=50,  # randomly rotate images in the range (degrees, 0 to 180)
105
            width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
106
            height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
107
            horizontal_flip=True,  # randomly flip images
108
            vertical_flip=False) 
109
        datagen.fit(X)
110
        history = m.fit_generator(datagen.flow(X, Y,
111
                                    batch_size=batch_size),
112
                                    steps_per_epoch=X.shape[0] // batch_size,
113
                                    epochs=epochs,
114
                                    verbose=verbose)         
115
    else:
116
        history = m.fit(X, Y, batch_size=batch_size, epochs=epochs, verbose=verbose)
117
    return history, m
118
119
def run(model='simple', X_to_pred=None, history=False, verbose=0, activation=None, epochs=20, data_augm=False):
120
    """
121
    Full pipeline for CNN: load the dataset, train the model and predict ROIs
122
    :param model: choice between different models e.g simple, larger, deeper, maxpooling
123
    :param activation: None if nothing passed, e.g : ReLu, tanh, etc.
124
    :param epochs: number of epochs
125
    :param X_to_pred: input for predictions after training (X_train if not specified)
126
    :param verbose: int for verbose
127
    :return: X, X_fullsize, Y, y_pred, h (if history boolean passed)
128
    """
129
    X, X_fullsize, Y, contour_mask = create_dataset()
130
    if model == 'simple':
131
        m = create_model(activation=activation)
132
    elif model == 'larger':
133
        m = create_model_larger(activation=activation)
134
    elif model == 'deeper':
135
        m = create_model_deeper(activation=activation)
136
    elif model == 'maxpooling':
137
        m = create_model_maxpooling(activation=activation)
138
    elif model =='full':
139
        m = create_model_full(activation=activation)
140
141
    m.compile(loss='mean_squared_error',
142
              optimizer='adam',
143
              metrics=['accuracy'])
144
    if verbose > 0:
145
        print('Size for each layer :\nLayer, Input Size, Output Size')
146
        for p in m.layers:
147
            print(p.name.title(), p.input_shape, p.output_shape)
148
    h, m = training(m, X, Y, verbose=verbose, batch_size=16, epochs=epochs, data_augm=data_augm)
149
150
    if not X_to_pred:
151
        X_to_pred = X
152
    y_pred = m.predict(X_to_pred, batch_size=16)
153
    
154
    if history:
155
        return X, X_fullsize, Y, contour_mask, y_pred, h, m
156
    else:
157
        return X, X_fullsize, Y, contour_mask, y_pred, m
158
159
def inference(model):
160
    X_test, X_fullsize_test, Y_test, contour_mask_test = create_dataset(n_set='test')
161
    y_pred = model.predict(X_test, batch_size=16)
162
    return X_test, X_fullsize_test, Y_test, contour_mask_test, y_pred