Diff of /train.py [000000] .. [38391a]

Switch to unified view

a b/train.py
1
from __future__ import absolute_import, division, print_function
2
3
import argparse
4
from datetime import datetime
5
from os import environ
6
7
import keras.backend as K
8
from keras.datasets import cifar10
9
import keras
10
import numpy as np
11
import pandas as pd
12
import tensorflow as tf
13
from random_eraser import get_random_eraser
14
15
from skimage import io, color, transform, exposure
16
from keras.applications import MobileNet, ResNet50
17
from keras.applications import DenseNet169, InceptionResNetV2,DenseNet201
18
from keras.applications.vgg19 import VGG19
19
from keras.callbacks import (EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard)
20
from keras.layers import Dense, GlobalAveragePooling2D, GlobalMaxPooling2D
21
from keras.metrics import binary_accuracy, binary_crossentropy, kappa_error, kullback_leibler_divergence
22
from keras.models import Model
23
from keras.optimizers import SGD, Adam
24
from keras.preprocessing.image import ImageDataGenerator
25
from sklearn.utils import class_weight
26
from keras.utils.training_utils import multi_gpu_model
27
from custom_layers import *
28
29
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Shut up tensorflow!
30
print("tf : {}".format(tf.__version__))
31
print("keras : {}".format(keras.__version__))
32
print("numpy : {}".format(np.__version__))
33
print("pandas : {}".format(pd.__version__))
34
35
parser = argparse.ArgumentParser(description='Hyperparameters')
36
parser.add_argument('--classes', default=1, type=int)
37
parser.add_argument('--workers', default=4, type=int)
38
parser.add_argument('--epochs', default=30, type=int)
39
parser.add_argument('-b', '--batch-size', default=8, type=int, help='mini-batch size')
40
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float)
41
parser.add_argument('--lr-wait', default=10, type=int, help='how long to wait on plateu')
42
parser.add_argument('--decay', default=1e-4, type=float)
43
parser.add_argument('--momentum', default=0.9, type=float)
44
parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
45
parser.add_argument('--fullretrain', dest='fullretrain', action='store_true', help='retrain all layers of the model')
46
parser.add_argument('--seed', default=1953, type=int, help='random seed')
47
parser.add_argument('--img_channels', default=3, type=int)
48
parser.add_argument('--img_size', default=499, type=int)
49
parser.add_argument('--early_stop', default=20, type=int)
50
51
52
#def preprocess_img():
53
#    def preprocess_img(img):
54
        # Histogram normalization in v channel
55
#         args = parser.parse_args()
56
57
 #        hsv = color.rgb2hsv(img)
58
 #        hsv[:, :, 2] = exposure.equalize_hist(hsv[:, :, 2])
59
 #        img = color.hsv2rgb(hsv)
60
61
            # central square crop
62
#         min_side = min(img.shape[:-1])
63
#         centre = img.shape[0] // 2, img.shape[1] // 2
64
#         img = img[centre[0] - min_side // 2:centre[0] + min_side // 2,
65
#          centre[1] - min_side // 2:centre[1] + min_side // 2,
66
#              :]
67
68
         # rescale to standard size
69
 #        img = transform.resize(img, (args.img_size, args.img_size))
70
71
         # roll color axis to axis 0
72
 #        img = np.rollaxis(img, -1)
73
74
 #        img = img.transpose([2,0,1])
75
 #        img = img.transpose([2,0,1])
76
 #        return img
77
 #   return preprocess_img
78
79
def train(args=None):
80
81
    args = parser.parse_args()
82
83
    img_shape = ( args.img_size, args.img_size, args.img_channels)  # blame theano
84
    now_iso = datetime.now().strftime('%Y-%m-%dT%H:%M:%S%z')
85
86
    #(x_train, y_train), (x_test,y_test) = cifar10.load_data()
87
88
    # We then scale the variable-sized images to 224x224
89
    # We augment .. by applying random lateral inversions and rotations.
90
    train_datagen = ImageDataGenerator(
91
        rescale=1. / 255,
92
        rotation_range=30,
93
#        contrast_stretching=True,
94
#        adaptive_equalization=True,
95
        histogram_equalization=True,
96
#        featurewise_center=True,
97
#        samplewise_center=True,
98
#        featurewise_std_normalization=True,
99
#        samplewise_std_normalization=True,
100
#        channel_shift_range=0.2,
101
#        brightness_range=[-0.3, 0.3],
102
        width_shift_range=0.2,
103
        height_shift_range=0.2,
104
        zoom_range=0.3,
105
        horizontal_flip=True,
106
#        preprocessing_function= get_random_eraser(v_l=0, v_h=1, pixel_level=True)
107
#        preprocessing_function=preprocess_img()
108
        )
109
    #train_datagen.fit(x_train)
110
111
112
    train_generator = train_datagen.flow_from_directory(
113
                     'data/train/',
114
                     shuffle=True,
115
                     target_size=(args.img_size, args.img_size),
116
#                     save_to_dir='data/AUG_ELBOW_HIST',
117
                     class_mode='binary',
118
#                     color_mode='grayscale',
119
#                     interpolation='bicubic',
120
                     batch_size=args.batch_size, )
121
122
    val_datagen = ImageDataGenerator(rescale=1. / 255,
123
#                                       contrast_stretching=True
124
#                                     ,adaptive_equalization=True
125
#                                      histogram_equalization=True
126
                                     )
127
    val_generator = val_datagen.flow_from_directory(
128
        'data/val/',
129
        shuffle=True,  # otherwise we get distorted batch-wise metrics
130
        class_mode='binary',
131
#        color_mode='grayscale',
132
        target_size=(args.img_size, args.img_size),
133
        batch_size=args.batch_size, )
134
#    val_datagen.fit(x_train)
135
136
    classes = len(train_generator.class_indices)
137
    assert classes > 0
138
    assert classes is len(val_generator.class_indices)
139
    n_of_train_samples = train_generator.samples
140
    n_of_val_samples = val_generator.samples
141
142
143
    # Architectures
144
    base_model = InceptionResNetV2(input_shape=img_shape, weights='imagenet', include_top=False)
145
146
147
    x = base_model.output  # Recast classification layer
148
149
    #x = Flatten()(x)  # Uncomment for Resnet based model
150
#    x = GlobalAveragePooling2D(name='predictions_avg_pool')(x)  # comment for RESNET models
151
    x = WildcatPool2d()(x)
152
    # n_classes; softmax for multi-class, sigmoid for binary
153
    x = Dense(args.classes, activation='sigmoid', name='predictions')(x)
154
155
    model = Model(inputs=base_model.input, outputs=x)
156
157
#    model = multi_gpu_model(model, gpus=2)
158
159
    # checkpoints
160
    #
161
    checkpoint = ModelCheckpoint(filepath='./models/InceptionResNetV2_499_NEW_HIST_WC_1.hdf5', verbose=1, save_best_only=True)
162
    early_stop = EarlyStopping(patience=args.early_stop)
163
    tensorboard = TensorBoard(log_dir='./logs/InceptionResNetV2_499_NEW_HIST_WC_1/{}/'.format(now_iso))
164
    reduce_lr = ReduceLROnPlateau(factor=0.03, cooldown=0, patience=args.lr_wait, min_lr=0.1e-6)
165
    callbacks = [checkpoint, tensorboard, reduce_lr]
166
167
    # Calculate class weights
168
    weights = class_weight.compute_class_weight('balanced', np.unique(train_generator.classes), train_generator.classes)
169
    weights = {0: weights[0], 1: weights[1]}
170
    # for layer in base_model.layers:
171
    #     layer.set_trainable = False
172
173
    #print(model.summary())
174
    #for i, layer in enumerate(base_model.layers):
175
    #     print(i, layer.name)
176
    if args.resume:
177
        model.load_weights(args.resume)
178
        for layer in model.layers:
179
            layer.set_trainable = True
180
181
    #if TRAIN_FULL:
182
    #    print("=> retrain all layers of network")
183
    #     for layer in model.layers:
184
    #         set_trainable = True
185
    #else:
186
    #     print("=> retraining only bottleneck and fc layers")
187
    #     import pdb
188
    #     pdb.set_trace()
189
    #     set_trainable = False
190
    #     for layer in base_model.layers:
191
    #         if "block12" in layer.name:  # what block do we want to start unfreezing
192
    #             set_trainable = True
193
    #         if set_trainable:
194
    #             layer.trainable = True
195
    #         else:
196
    #             layer.trainable = False
197
198
# The network is trained end-to-end using Adam with default parameters
199
    model.compile(
200
        optimizer=Adam(lr=args.lr, decay=args.decay),
201
#        optimizer=SGD(lr=args.lr, decay=args.decay,momentum=args.momentum, nesterov=True),
202
        loss=binary_crossentropy,
203
#        loss=kappa_error,
204
        metrics=['accuracy', binary_accuracy], )
205
206
    model_out = model.fit_generator(
207
        train_generator,
208
        steps_per_epoch=n_of_train_samples // args.batch_size,
209
        epochs=args.epochs,
210
        validation_data=val_generator,
211
        validation_steps=n_of_val_samples // args.batch_size,
212
        class_weight=weights,
213
        workers=args.workers,
214
        use_multiprocessing=True,
215
        callbacks=callbacks)
216
217
218
if __name__ == '__main__':
219
    train()