# training-models/cnn_lstm_model.py
1
# import necessary libraries
2
from keras.preprocessing.image import ImageDataGenerator
3
from keras.applications.vgg16 import VGG16
4
from keras.layers import Dense, LSTM, Dropout, Flatten, Reshape
5
from keras.models import Sequential
6
7
# set up the data generators
#
# The VGG16 backbone below uses frozen ImageNet weights, which expect inputs
# prepared by keras.applications.vgg16.preprocess_input (RGB -> BGR and
# per-channel ImageNet mean subtraction, no 0-1 scaling). Feeding it
# rescale=1./255 pixels instead moves the inputs away from what the frozen
# filters were trained on, so every split uses the matching preprocessing.
from keras.applications.vgg16 import preprocess_input

train_dir = 'X-ray Images/train'
val_dir = 'X-ray Images/validation'
test_dir = 'X-ray Images/test'
img_height = 224
img_width = 224
batch_size = 32

train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Training batches stay shuffled (flow_from_directory's default).
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical')

# Validation and test batches are NOT shuffled. Every epoch then sees the same
# ordering, and test predictions come back in the same order as
# test_generator.filenames / test_generator.classes. With the default
# shuffle=True that ordering is lost.
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False)

# Convolutional backbone: VGG16 with ImageNet weights and without its
# classifier layers, sized for the generators' target images.
vgg_model = VGG16(
    weights='imagenet',
    include_top=False,
    input_shape=(img_height, img_width, 3),
)

# Freeze every backbone layer so only the head stacked below gets trained.
for layer in vgg_model.layers:
    layer.trainable = False

# CNN-LSTM head on top of the frozen backbone.
# NOTE(review): Reshape((1, -1)) turns the flattened feature map into a
# sequence of length one, so the LSTM only ever sees a single timestep. If a
# real sequence was intended, reshaping the backbone's feature map into
# (spatial positions, channels) would give it one step per location.
# TODO confirm intent.
model = Sequential([
    vgg_model,
    Flatten(),
    Reshape((1, -1)),
    LSTM(256, return_sequences=False),
    Dropout(0.5),
    Dense(128, activation='relu'),
    Dropout(0.5),
    # One softmax output per class directory found under train_dir.
    Dense(train_generator.num_classes, activation='softmax'),
])

# Multi-class setup: one-hot ('categorical') labels with softmax outputs.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# train the model
#
# Model.fit accepts generators / keras Sequences directly. fit_generator is
# deprecated and removed in newer Keras releases.
# len(generator) is the number of batches needed to cover every sample, so:
#   - the final partial batch is not silently dropped (as n // batch_size does);
#   - the step count cannot become 0 when a split holds fewer images than
#     batch_size;
#   - validation metrics are computed over the whole validation set.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=10,
    validation_data=val_generator,
    validation_steps=len(val_generator))

# evaluate the model on the test data
#
# model.evaluate replaces the deprecated evaluate_generator. Passing
# len(test_generator) steps includes the final partial batch, so the reported
# accuracy covers every test image. The previous n // batch_size could skip
# up to batch_size - 1 of them.
# With metrics=['accuracy'] the result is [loss, accuracy].
test_loss, test_acc = model.evaluate(test_generator, steps=len(test_generator))
print('Test accuracy:', test_acc)

# make predictions on new data
#
# model.predict replaces the deprecated predict_generator. Using
# len(test_generator) steps yields one row of class probabilities per test
# image, including those in the final partial batch.
# NOTE: rows line up with test_generator.filenames / test_generator.classes
# only if test_generator was created with shuffle=False. With shuffling the
# original order cannot be recovered afterwards.
predictions = model.predict(test_generator, steps=len(test_generator))