--- a +++ b/model.py @@ -0,0 +1,91 @@ +from keras.layers import Dense, Flatten +from keras.models import Model +from keras.applications.vgg16 import VGG16 +from keras.preprocessing.image import ImageDataGenerator + +from glob import glob +import matplotlib.pyplot as plt + +# re-size all the images to this +IMAGE_SIZE = [224, 224] + +train_path = 'dataset/TRAIN' +valid_path = 'dataset/TEST' + +# add preprocessing layer to the front of VGG +vgg = VGG16(input_shape=IMAGE_SIZE + [3], weights='imagenet', include_top=False) + +# don't train existing weights +for layer in vgg.layers: + layer.trainable = False + + +#useful for getting number of classes +folders = glob('dataset/TRAIN/*') + + +# our layers - you can add more if you want +x = Flatten()(vgg.output) + +#add the sigmoid as the activation function +prediction = Dense(1, activation='sigmoid')(x) + +# create a model object +model = Model(inputs=vgg.input, outputs=prediction) + +# view the structure of the model +model.summary() + +# tell the model what cost and optimization method to use +model.compile( + loss='binary_crossentropy', + optimizer='adam', + metrics=['accuracy'] +) + + +train_datagen = ImageDataGenerator(rescale = 1./255, + shear_range = 0.2, + zoom_range = 0.2, + horizontal_flip = True) + +test_datagen = ImageDataGenerator(rescale = 1./255) + +training_set = train_datagen.flow_from_directory('dataset/TRAIN', + target_size = (224, 224), + batch_size = 64, + class_mode = 'binary') + +test_set = test_datagen.flow_from_directory('dataset/TEST', + target_size = (224, 224), + batch_size = 64, + class_mode = 'binary') + +# see which class represents 1 and which represents 0 +training_set.class_indices + +# fit the model +r = model.fit_generator( + training_set, + validation_data=test_set, + epochs=4, + steps_per_epoch=len(training_set), + validation_steps=len(test_set) +) +# loss plots +plt.plot(r.history['loss'], label='train loss') +plt.plot(r.history['val_loss'], label='val loss') +plt.legend() +plt.show() +plt.savefig('LossVal_loss') + +# accuracy plots +plt.plot(r.history['accuracy'], label='train acc') +plt.plot(r.history['val_accuracy'], label='val acc') +plt.legend() +plt.show() +plt.savefig('AccVal_acc') + +#save our model in order to use it in web development +#model.save('Esophageal_model.h5') +