"""Train a small CNN image classifier on the class folders '1'..'4'.

Usage::

    python train-classifier.py TRAIN_DIR VALIDATION_DIR

Each data root must contain one sub-directory per class. The model
architecture (JSON), per-epoch weight checkpoints and a CSV training log are
written to a time-stamped output directory.
"""
# for saving models and csvlogger
import datetime
import itertools
import os
import sys

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers

print("TF version: ", tf.version.VERSION)

tf.keras.backend.clear_session()  # For easy reset of notebook state.
16
17
#  config = tf.ConfigProto()
18
#  config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
19
#  sess = tf.Session(config=config)
20
#  K.set_session(sess)  # set this TensorFlow session as the default session for Keras
21
22
# dimensions of our images.
img_width, img_height = 150, 150
# IMAGE_SIZE    = (100, 100)
# CROP_LENGTH   = 84

if len(sys.argv) != 3:
    # A usage error is a failure: report it on stderr and exit non-zero so the
    # calling shell / job scheduler can detect it. sys.exit is used because the
    # bare `exit` builtin is only injected by the `site` module.
    print('Error: pls provide train path and validation path', file=sys.stderr)
    print('usage: {} TRAIN_DIR VALIDATION_DIR'.format(sys.argv[0]), file=sys.stderr)
    sys.exit(1)

train_data_dir = sys.argv[1]
validation_data_dir = sys.argv[2]
34
35
36
# Class sub-folders expected under both the training and validation roots.
# The per-class count lists below follow this order.
folders = ['1', '2', '3', '4']
nb_classes = len(folders)


def _count_entries_per_class(root, class_dirs):
    """Return the number of directory entries in each ``root/<class_dir>``.

    Note: every entry (files and sub-directories alike) is counted, so this
    equals the number of images only when the class folders hold nothing else.
    Raises OSError (e.g. FileNotFoundError) if a class folder is missing.
    """
    return [len(os.listdir(os.path.join(root, d))) for d in class_dirs]


nb_sample_per_class = _count_entries_per_class(train_data_dir, folders)
nb_val_sample_per_class = _count_entries_per_class(validation_data_dir, folders)
nb_train_samples = sum(nb_sample_per_class)
nb_validation_samples = sum(nb_val_sample_per_class)

print("\nnb_train_samples: ", nb_train_samples)
print("\nnb_validation_samples: ", nb_validation_samples)
print("\nnb_sample_per_class: ", nb_sample_per_class)
print("\nnb_val_sample_per_class: ", nb_val_sample_per_class)
print("--")
62
63
# Training schedule: number of passes over the data and images per batch.
epochs, batch_size = 100, 128
65
66
from functools import partial, update_wrapper
67
68
def wrapped_partial(func, *args, **kwargs):
    """Return ``functools.partial(func, *args, **kwargs)`` disguised as *func*.

    A bare ``partial`` object carries no ``__name__``/``__doc__``. Copying
    them from *func* with ``functools.update_wrapper`` lets code that reads a
    callable's name treat the partial like the original function.
    """
    bound = partial(func, *args, **kwargs)
    # update_wrapper mutates `bound` in place and hands the same object back.
    return update_wrapper(bound, func)
72
73
def w_categorical_crossentropy(y_true, y_pred, weights):
    """Categorical cross-entropy scaled by a (true, predicted) cost matrix.

    Each sample's cross-entropy term is multiplied by ``weights[t, p]``, where
    ``t`` is the sample's true class (from ``y_true``) and ``p`` is its
    arg-max predicted class. This penalises some kinds of misclassification
    more than others.

    Args:
        y_true: one-hot targets; presumably shape (batch, nb_classes).
        y_pred: predicted class probabilities, same shape as ``y_true``.
        weights: square array (presumably a NumPy array) indexed as
            ``weights[true_class, predicted_class]``; its length is the
            number of classes.

    Returns:
        Per-sample weighted loss, shape (batch,).
    """
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    # Boolean (batch, nb_cl) mask that is True at the arg-max prediction.
    # On ties several columns are True and their weights are summed.
    y_pred_max = K.max(y_pred, axis=1, keepdims=True)
    y_pred_max_mat = K.equal(y_pred, y_pred_max)
    for c_p, c_t in itertools.product(range(nb_cl), range(nb_cl)):
        final_mask += (
            K.cast(weights[c_t, c_p], K.floatx())
            * K.cast(y_pred_max_mat[:, c_p], K.floatx())
            * K.cast(y_true[:, c_t], K.floatx())
        )
    # tf.keras signature is categorical_crossentropy(target, output); passing
    # (y_pred, y_true) -- the Keras 1.x order -- scores the labels against the
    # predictions and produces a meaningless loss.
    return K.categorical_crossentropy(y_true, y_pred) * final_mask
82
83
# Misclassification cost matrix for the weighted loss, indexed as
# w_array[true_class, predicted_class] in the order of nb_sample_per_class.
# Correct predictions (the diagonal) keep weight 1. A wrong prediction for a
# sample of true class t is weighted n_0 / n_t (n_i = training count of
# class i), so errors on classes rarer than class 0 cost more.
if len(nb_sample_per_class) != nb_classes:
    raise ValueError(
        "expected {} per-class training counts, got {}".format(
            nb_classes, nb_sample_per_class))
if 0 in nb_sample_per_class:
    # An empty class folder would divide by zero (or, for class 0, zero out
    # every misclassification weight); fail with a readable message instead.
    raise ValueError(
        "every training class folder must be non-empty; per-class counts: {}".format(
            nb_sample_per_class))

w_array = np.ones((nb_classes, nb_classes))
for _true_cls in range(nb_classes):
    w_array[_true_cls, :] = (
        float(nb_sample_per_class[0]) / float(nb_sample_per_class[_true_cls]))
    w_array[_true_cls, _true_cls] = 1.0

# Bind the weights to get a (y_true, y_pred) loss. wrapped_partial copies
# __name__ ('w_categorical_crossentropy') and __doc__ onto the partial.
ncce = wrapped_partial(w_categorical_crossentropy, weights=w_array)
99
100
# sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
101
102
103
#  """ # RSNA version custom small model
104
# Small custom CNN: six Conv(3x3) -> ReLU -> BatchNorm -> MaxPool(2x2) stages,
# then a 512-unit dense head and a softmax over nb_classes. With a 150x150
# input the last pooling stage leaves a 1x1x32 feature map for Flatten.
model = tf.keras.Sequential([
    # stage 1 -- named 'conv1' and declares the (height, width, 3) input
    layers.Conv2D(32, (3, 3), input_shape=(img_height, img_width, 3), name='conv1'),
    layers.Activation('relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Dropout(0.2),

    # stage 2
    layers.Conv2D(64, (3, 3)),
    layers.Activation('relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(pool_size=(2, 2)),

    # stage 3
    layers.Conv2D(64, (3, 3)),
    layers.Activation('relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(pool_size=(2, 2)),

    # stage 4
    layers.Conv2D(64, (3, 3)),
    layers.Activation('relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(pool_size=(2, 2)),

    # stage 5 -- 'same' padding keeps the small feature map from vanishing
    layers.Conv2D(32, (3, 3), padding='same'),
    layers.Activation('relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(pool_size=(2, 2)),

    # stage 6
    layers.Conv2D(32, (3, 3), padding='same'),
    layers.Activation('relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(pool_size=(2, 2)),

    layers.Dropout(0.2),

    # classifier head
    layers.Flatten(),
    layers.Dense(512),
    layers.Activation('relu'),
    layers.Dropout(0.2),
    layers.Dense(nb_classes, name='output'),
    layers.Activation('softmax'),
])
147
#  """
148
149
# model.compile(loss=ncce,
150
#               optimizer=sgd,
151
#               metrics=['accuracy'])
152
153
154
# base_model = InceptionV3(include_top=False, weights='imagenet', input_shape=(img_height, img_width, 3))
155
156
# x = base_model.output
157
# x = GlobalAveragePooling2D()(x)
158
# # let's add a fully-connected layer
159
# x = Dense(1024, activation='relu')(x)
160
# # and a logistic layer -- let's say we have 200 classes
161
# predictions = Dense(4, activation='softmax')(x)
162
163
# model = Model(inputs=base_model.input, outputs=predictions)
164
165
166
# Train with default-configured Adam on the cost-weighted cross-entropy,
# reporting plain accuracy alongside the loss.
model.compile(
    loss=ncce,
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'],
)

# Print a layer-by-layer summary (output shapes, parameter counts).
model.summary()
172
173
# # random crop patch from:
174
# # https://mc.ai/extending-keras-imagedatagenerator-to-support-random-cropping/
175
# def random_crop(img, random_crop_size):
176
#     # Note: image_data_format is 'channel_last'
177
#     assert img.shape[2] == 3
178
#     height, width = img.shape[0], img.shape[1]
179
#     dy, dx = random_crop_size
180
#     x = np.random.randint(0, width - dx + 1)
181
#     y = np.random.randint(0, height - dy + 1)
182
#     return img[y:(y+dy), x:(x+dx), :]
183
184
185
# def crop_generator(batches, crop_length):
186
#     """Take as input a Keras ImageGen (Iterator) and generate random
187
#     crops from the image batches generated by the original iterator.
188
#     """
189
#     while True:
190
#         batch_x, batch_y = next(batches)
191
#         batch_crops = np.zeros((batch_x.shape[0], crop_length, crop_length, 3))
192
#         for i in range(batch_x.shape[0]):
193
#             batch_crops[i] = random_crop(batch_x[i], (crop_length, crop_length))
194
#         yield (batch_crops, batch_y)
195
196
197
# Training images: scale pixels to [0, 1] and apply light geometric
# augmentation. Disabled options kept for reference:
#   horizontal_flip=True, width_shift_range=0.01, height_shift_range=0.01,
#   brightness_range=[0.2, 1.0]
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=5,
    shear_range=0.1,
    zoom_range=0.1,
    rescale=1. / 255,
)

# Validation images: rescaling only, no augmentation.
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
)
212
213
# flow_from_directory prints "Found N images belonging to M classes".
# classes=folders pins label index i to folders[i]. This is the same order
# used for nb_sample_per_class (and hence the loss weight matrix), and any
# other sub-directory under the data roots is ignored rather than becoming
# an extra class.
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    classes=folders,
    batch_size=batch_size,
    class_mode='categorical')

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_height, img_width),
    classes=folders,
    batch_size=batch_size,
    class_mode='categorical')

# Sanity check: pull one (image_batch, label_batch) pair from each generator
# and print the shapes.
for _gen in (train_generator, validation_generator):
    image_batch, label_batch = next(_gen)
    print("Image batch shape: ", image_batch.shape)
    print("Label batch shape: ", label_batch.shape)
236
237
238
# # crop from 100 to CROP_LENGTH = 84
239
# train_crops = crop_generator(train_batches, CROP_LENGTH)
240
241
########################
# output location
########################
# Every artefact of this run (architecture JSON, per-epoch weight
# checkpoints, CSV log) goes into one directory whose name is stamped with
# the start time as month-day-hour-minute.
formatted_time = datetime.datetime.now().strftime("%m%d-%H%M")
save_model_dir = "Axial_center_resnetscale150V1_150x150bat128_6LDropout_Date" + formatted_time

dir_missing = not os.path.exists(save_model_dir)
if dir_missing:
    print(save_model_dir, " will be created")
    os.makedirs(save_model_dir)

# Serialise the model architecture to JSON inside the run directory.
store_model_json_name = "axial-center-Date" + formatted_time + ".json"
model_json = model.to_json()
model_json_path = os.path.join(save_model_dir, store_model_json_name)
with open(model_json_path, "w") as json_file:
    json_file.write(model_json)
266
267
# Log key for validation accuracy with metrics=['accuracy']: tf.keras 2.x
# reports it as 'val_accuracy', 1.x as 'val_acc'. A key that is not in the
# epoch logs makes ModelCheckpoint's filepath formatting raise KeyError at
# the first save.
_val_acc_key = 'val_acc' if tf.version.VERSION.startswith('1.') else 'val_accuracy'

# e.g. <save_model_dir>/<save_model_dir>_Ep07_ValAcc0.812_ValLoss0.45.h5;
# the doubled braces survive .format() as placeholders filled by Keras.
checkpoint_filepath = os.path.join(
    save_model_dir,
    "{0}_Ep{{epoch:02d}}_ValAcc{{{1}:.3f}}_ValLoss{{val_loss:.2f}}.h5"
    .format(save_model_dir, _val_acc_key)
)

callbacks = [
    # tf.keras.callbacks.EarlyStopping(
    #     monitor='val_loss',
    #     patience=15,
    #     verbose=1),
    # save_weights_only=True: only model.save_weights(filepath) is written
    # every epoch (save_best_only=False); the architecture JSON is stored
    # separately in save_model_dir.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        monitor=_val_acc_key,
        save_best_only=False,
        save_weights_only=True,
        verbose=1,
        save_freq="epoch"),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,  # usually 0.1
        patience=10,
        verbose=1),
    tf.keras.callbacks.CSVLogger(
        filename=os.path.join(
            save_model_dir,
            '{}.csv'.format(save_model_dir)
        ),
        append=False,
        separator=','),
]
298
299
300
# Steps per epoch must be integers. len() of a flow_from_directory iterator is
# the number of batches covering the images it actually found, which makes
# one epoch exactly one pass over the data. np.ceil(...) here produced a
# float64, and the os.listdir-based counts can include non-image entries.
train_steps = len(train_generator)
print("len train_generator: ", train_steps)

val_steps = len(validation_generator)
print("len validation_generator: ", val_steps)

train_history = model.fit(
    train_generator,
    steps_per_epoch=train_steps,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=val_steps,
    callbacks=callbacks)