--- /dev/null
+++ b/train.py
@@ -0,0 +1,219 @@
+from __future__ import absolute_import, division, print_function
+
+import argparse
+from datetime import datetime
+from os import environ
+
+import keras
+import keras.backend as K
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from keras.applications import InceptionResNetV2
+# Backbones tried previously: MobileNet, ResNet50, VGG19, DenseNet169, DenseNet201
+from keras.callbacks import (EarlyStopping, ModelCheckpoint,
+                             ReduceLROnPlateau, TensorBoard)
+from keras.layers import Dense, GlobalAveragePooling2D
+from keras.losses import binary_crossentropy
+from keras.metrics import binary_accuracy
+from keras.models import Model
+from keras.optimizers import SGD, Adam
+from keras.preprocessing.image import ImageDataGenerator
+from keras.utils import multi_gpu_model
+from skimage import color, exposure, transform
+from sklearn.utils import class_weight
+
+from custom_layers import WildcatPool2d
+from random_eraser import get_random_eraser
+
+environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Shut up tensorflow!
+print("tf     : {}".format(tf.__version__))
+print("keras  : {}".format(keras.__version__))
+print("numpy  : {}".format(np.__version__))
+print("pandas : {}".format(pd.__version__))
+
+parser = argparse.ArgumentParser(description='Hyperparameters')
+parser.add_argument('--classes', default=1, type=int)
+parser.add_argument('--workers', default=4, type=int)
+parser.add_argument('--epochs', default=30, type=int)
+parser.add_argument('-b', '--batch-size', default=8, type=int, help='mini-batch size')
+parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float)
+parser.add_argument('--lr-wait', default=10, type=int, help='epochs to wait on a plateau before dropping the learning rate')
+parser.add_argument('--decay', default=1e-4, type=float)
+parser.add_argument('--momentum', default=0.9, type=float)
+parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
+parser.add_argument('--fullretrain', dest='fullretrain', action='store_true', help='retrain all layers of the model')
+parser.add_argument('--seed', default=1953, type=int, help='random seed')
+parser.add_argument('--img_channels', default=3, type=int)
+parser.add_argument('--img_size', default=499, type=int)
+parser.add_argument('--early_stop', default=20, type=int, help='patience (in epochs) for EarlyStopping')
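+
+
+# `WildcatPool2d` is imported from custom_layers.py, which this diff does not
+# touch. The sketch below only illustrates the WILDCAT-style spatial pooling
+# it is assumed to implement (Durand et al., CVPR 2017): each class map is
+# reduced to its strongest positive evidence plus alpha times its strongest
+# negative evidence (kmax = kmin = 1 assumed here).
+class WildcatPool2dSketch(keras.layers.Layer):
+    def __init__(self, alpha=1.0, **kwargs):
+        self.alpha = alpha
+        super(WildcatPool2dSketch, self).__init__(**kwargs)
+
+    def call(self, x):
+        # x: (batch, rows, cols, channels) -> flatten the spatial grid
+        _, rows, cols, channels = K.int_shape(x)
+        flat = K.reshape(x, (-1, rows * cols, channels))
+        # max over locations = positive evidence, min = negative evidence
+        return K.max(flat, axis=1) + self.alpha * K.min(flat, axis=1)
+
+    def compute_output_shape(self, input_shape):
+        return (input_shape[0], input_shape[3])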
+
+
+def get_preprocess_img(img_size):
+    """Build an optional skimage preprocessing function: histogram
+    normalisation in the HSV v-channel, central square crop, resize.
+    Currently unused; see the commented-out preprocessing_function below."""
+    def preprocess_img(img):
+        # Histogram normalisation in the v channel
+        hsv = color.rgb2hsv(img)
+        hsv[:, :, 2] = exposure.equalize_hist(hsv[:, :, 2])
+        img = color.hsv2rgb(hsv)
+        # Central square crop
+        min_side = min(img.shape[:-1])
+        centre = img.shape[0] // 2, img.shape[1] // 2
+        img = img[centre[0] - min_side // 2:centre[0] + min_side // 2,
+                  centre[1] - min_side // 2:centre[1] + min_side // 2,
+                  :]
+        # Rescale to the standard size
+        return transform.resize(img, (img_size, img_size))
+    return preprocess_img
+
+
+def train(args=None):
+    args = parser.parse_args()
+
+    # Apply the --seed flag; GPU kernels may still be nondeterministic
+    np.random.seed(args.seed)
+    tf.set_random_seed(args.seed)
+
+    # channels-last, i.e. the default TensorFlow ordering
+    img_shape = (args.img_size, args.img_size, args.img_channels)
+    now_iso = datetime.now().strftime('%Y-%m-%dT%H:%M:%S%z')
+
+    # Variable-sized images are scaled to img_size x img_size and augmented
+    # by random lateral inversions, rotations, shifts and zooms.
+    train_datagen = ImageDataGenerator(
+        rescale=1. / 255,
+        rotation_range=30,
+        # histogram_equalization (like contrast_stretching and
+        # adaptive_equalization) is not in stock Keras; it needs a patched
+        # ImageDataGenerator fork
+        histogram_equalization=True,
+        width_shift_range=0.2,
+        height_shift_range=0.2,
+        zoom_range=0.3,
+        horizontal_flip=True,
+        # preprocessing_function=get_random_eraser(v_l=0, v_h=1, pixel_level=True)
+        # preprocessing_function=get_preprocess_img(args.img_size)
+    )
+
+    train_generator = train_datagen.flow_from_directory(
+        'data/train/',
+        shuffle=True,
+        target_size=(args.img_size, args.img_size),
+        # save_to_dir='data/AUG_ELBOW_HIST',  # uncomment to inspect augmented samples
+        class_mode='binary',
+        batch_size=args.batch_size)
+
+    val_datagen = ImageDataGenerator(
+        rescale=1. / 255,
+        # histogram_equalization=True,
+    )
+    val_generator = val_datagen.flow_from_directory(
+        'data/val/',
+        shuffle=True,  # otherwise each batch holds a single class and batch-wise metrics are distorted
+        class_mode='binary',
+        target_size=(args.img_size, args.img_size),
+        batch_size=args.batch_size)
+
+    classes = len(train_generator.class_indices)
+    assert classes > 0
+    assert classes == len(val_generator.class_indices)
+    n_of_train_samples = train_generator.samples
+    n_of_val_samples = val_generator.samples
+
+    # Architecture: ImageNet-pretrained backbone with the classifier recast
+    base_model = InceptionResNetV2(input_shape=img_shape, weights='imagenet', include_top=False)
+
+    x = base_model.output
+    # x = GlobalAveragePooling2D(name='predictions_avg_pool')(x)  # the usual alternative to WildcatPool2d
+    x = WildcatPool2d()(x)
+    # sigmoid for binary classification; use softmax (and --classes > 1) for multi-class
+    x = Dense(args.classes, activation='sigmoid', name='predictions')(x)
+
+    model = Model(inputs=base_model.input, outputs=x)
+    # model = multi_gpu_model(model, gpus=2)
+
+    # Callbacks: checkpoint the best model, stop early, anneal the LR, log to TensorBoard
+    run_name = 'InceptionResNetV2_499_NEW_HIST_WC_1'
+    checkpoint = ModelCheckpoint(filepath='./models/{}.hdf5'.format(run_name), verbose=1, save_best_only=True)
+    early_stop = EarlyStopping(patience=args.early_stop)
+    tensorboard = TensorBoard(log_dir='./logs/{}/{}/'.format(run_name, now_iso))
+    reduce_lr = ReduceLROnPlateau(factor=0.03, cooldown=0, patience=args.lr_wait, min_lr=1e-7)
+    callbacks = [checkpoint, early_stop, tensorboard, reduce_lr]
+
+    # Weight the loss inversely to class frequency
+    weights = class_weight.compute_class_weight('balanced', np.unique(train_generator.classes), train_generator.classes)
+    weights = {0: weights[0], 1: weights[1]}
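+    # For example, with labels [0, 0, 0, 1] the 'balanced' heuristic gives
+    # n_samples / (n_classes * count): {0: 4 / (2 * 3) ~ 0.67, 1: 4 / (2 * 1) = 2.0},
+    # so the minority class contributes proportionally more to the loss.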
+
+    if args.resume:
+        model.load_weights(args.resume)
+
+    if args.fullretrain or args.resume:
+        print("=> retraining all layers of the network")
+        for layer in model.layers:
+            layer.trainable = True
+    else:
+        print("=> retraining only the bottleneck and fc layers")
+        set_trainable = False
+        for layer in base_model.layers:
+            # first block to unfreeze; the prefix must match the backbone's
+            # layer naming (InceptionResNetV2 uses block35/block17/block8)
+            if "block12" in layer.name:
+                set_trainable = True
+            layer.trainable = set_trainable
+
+    # The network is trained end-to-end using Adam with default parameters
+    model.compile(
+        optimizer=Adam(lr=args.lr, decay=args.decay),
+        # optimizer=SGD(lr=args.lr, decay=args.decay, momentum=args.momentum, nesterov=True),
+        loss=binary_crossentropy,
+        # loss=kappa_error,  # custom loss, not shipped with Keras
+        metrics=['accuracy', binary_accuracy])
+
+    model.fit_generator(
+        train_generator,
+        steps_per_epoch=n_of_train_samples // args.batch_size,
+        epochs=args.epochs,
+        validation_data=val_generator,
+        validation_steps=n_of_val_samples // args.batch_size,
+        class_weight=weights,
+        workers=args.workers,
+        use_multiprocessing=True,
+        callbacks=callbacks)
+
+
+if __name__ == '__main__':
+    train()
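+
+# Example invocations (illustrative; data/train/ and data/val/ must each
+# contain one sub-directory per class):
+#   python train.py --epochs 30 -b 8 --lr 1e-4 --img_size 499
+#   python train.py --resume ./models/InceptionResNetV2_499_NEW_HIST_WC_1.hdf5 --fullretrain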