# -*- coding: utf-8 -*-
"""Train_LSTM.ipynb
**
 * This file is part of the Hybrid CNN-LSTM for COVID-19 Severity Score Prediction paper.
 *
 * Written by Ankan Ghosh Dastider and Farhan Sadik.
 *
 * Copyright (c) by the authors under Apache-2.0 License. Some rights reserved, see LICENSE.
 */

"""

import os

import numpy as np
import pandas as pd

# NOTE(review): this file also uses cv2, tqdm, train_test_split,
# confusion_matrix, plt and sns without importing them anywhere; they must be
# imported before the later sections can run.

# Loading frames of the videos sequentially
video_types=['Video 01', 'Video 05', 'Video 06', 'Video 07', 'Video 08', 'Video 09', 'Video 10', 'Video 14',
             'Video 15', 'Video 16', 'Video 17', 'Video 20', 'Video 21', 'Video 27', 'Video 29']

NUM_VIDEOS = len(video_types)
NUM_FRAMES = 302  # number of frames kept per video

data_dir_lstm = ''  # Link Training Directory videowise
train_dir_lstm = os.path.join(data_dir_lstm)

# One [relative file path, folder index, folder name] row per kept frame.
train_data_lstm = []
for defects_id, sp in enumerate(video_types):
    temporary = [
        ['{}/{}'.format(sp, file), defects_id, sp]
        for file in sorted(os.listdir(os.path.join(train_dir_lstm, sp)))
    ]

    total_frames = len(temporary)
    if total_frames == 0:
        # Without this check the sampling below fails with a bare IndexError.
        raise ValueError('No frames found for {!r} in {!r}'.format(sp, train_dir_lstm))

    # NUM_FRAMES evenly spaced positions over the sorted frame list; when a
    # folder holds fewer than NUM_FRAMES files some positions repeat.
    index = np.linspace(start=0, stop=total_frames - 1, num=NUM_FRAMES, dtype=int)
    train_data_lstm.extend(temporary[position] for position in index)

train_on_lstm = pd.DataFrame(train_data_lstm, columns=['File', 'FolderID','Video Type'])
train_on_lstm.head(NUM_VIDEOS * NUM_FRAMES)

# Unsampled listing of every frame in every video folder, kept for inspection.
# It is stored under its own names so that it does not overwrite the
# NUM_FRAMES-per-video table (`train_on_lstm`) built above: the loader further
# down slices that table into consecutive NUM_FRAMES-sized chunks, one per
# video, which is only correct when each video contributes exactly NUM_FRAMES
# rows.
all_frames_lstm = [
    ['{}/{}'.format(sp, file), defects_id, sp]
    for defects_id, sp in enumerate(video_types)
    for file in sorted(os.listdir(os.path.join(train_dir_lstm, sp)))
]

all_frames_on_lstm = pd.DataFrame(all_frames_lstm, columns=['File', 'FolderID','Video Type'])
all_frames_on_lstm.head()

# Training schedule of the LSTM stage.
EPOCHS_LSTM = 120
BATCH_SIZE_LSTM = 25

# Frames are resized to IMAGE_SIZE x IMAGE_SIZE pixels.
IMAGE_SIZE = 128
# Random state handed to the train/validation split.
SEED = 42

def read_image_lstm(filepath):
    """Read one frame from the training directory.

    Args:
        filepath: Path of the image file, relative to ``data_dir_lstm``.

    Returns:
        The result of ``cv2.imread`` for the joined path, called with its
        default flag (a color image). Callers in this file check the result
        against ``None`` before using it.
    """
    full_path = os.path.join(data_dir_lstm, filepath)
    return cv2.imread(full_path)


def resize_image(newimage, image_size):
    """Resize an image to the target size.

    Args:
        newimage: Image array to resize; it is copied first, so the resize
            operates on the copy rather than on the caller's array.
        image_size: Target size, passed to ``cv2.resize`` as its size argument.

    Returns:
        The resized image, interpolated with ``cv2.INTER_AREA``.
    """
    working_copy = newimage.copy()
    return cv2.resize(working_copy, image_size, interpolation=cv2.INTER_AREA)

from tensorflow.keras.models import load_model
69
import re
70
from keras import backend as K
71
72
# Frame tensor indexed as (video, frame, height, width, channel) and the
# matching per-frame score tensor.
X_Train_Total = np.zeros((NUM_VIDEOS, NUM_FRAMES, IMAGE_SIZE, IMAGE_SIZE, 3))
Y_Train_Total = np.zeros((NUM_VIDEOS, NUM_FRAMES, 1))

# The score is the single digit that follows "Score" in the file path.
# Compiled once, and written as a raw string so that `\d` is not treated as
# an (invalid) string escape.
score_pattern = re.compile(r'Score(\d)')

# Rows are consumed in table order: every run of NUM_FRAMES consecutive rows
# fills one video slot k, and j is the frame position inside that slot.
# tqdm wraps the array itself so the progress bar knows the total.
for i, file in enumerate(tqdm(train_on_lstm['File'].values)):
    k, j = divmod(i, NUM_FRAMES)
    if k == NUM_VIDEOS:
        break

    newimage = read_image_lstm(file)
    if newimage is None:
        # Unreadable frame: its slot keeps the all-zero image and label 0.
        continue

    match = score_pattern.search(file)
    if match is None:
        raise ValueError('No "Score<digit>" label found in file name: {}'.format(file))

    X_Train_Total[k, j] = resize_image(newimage, (IMAGE_SIZE, IMAGE_SIZE))
    Y_Train_Total[k, j] = int(match.group(1))

# `to_categorical` is not imported anywhere else in this file.
from keras.utils import to_categorical

# One-hot encode the per-frame scores into 4 classes.
Y_Train_Total = to_categorical(Y_Train_Total, num_classes=4)
# Normalize the data (in place, to avoid allocating a second full-size array)
X_Train_Total /= 255.
print('X_Train_Total Shape: {}'.format(X_Train_Total.shape))
print('Y_Train_Total Shape: {}'.format(Y_Train_Total.shape))

# Shuffle along the first axis. Re-seeding before each in-place shuffle makes
# both arrays receive the same permutation, so inputs and labels stay aligned.
for array in (X_Train_Total, Y_Train_Total):
    np.random.seed(42)
    np.random.shuffle(array)

print('X_Train_Total Shape: {}'.format(X_Train_Total.shape))
print('Y_Train_Total Shape: {}'.format(Y_Train_Total.shape))

model = load_model('')  # Link the CNN weights
model.summary()

# Backend function mapping the CNN input to the output of its 'dropout_35'
# layer. It does not depend on the video, so it is built once instead of once
# per loop iteration.
specific_layer_output = K.function([model.layers[0].input],
                                   [model.get_layer('dropout_35').output])

# One 64-value feature vector per frame of every video.
output = np.zeros((NUM_VIDEOS, NUM_FRAMES, 64))
for i in range(NUM_VIDEOS):
    # All frames of video i are passed through the CNN as a single batch.
    output[i] = specific_layer_output([X_Train_Total[i]])[0]

print('Output from CNN Shape: {}'.format(output.shape))

# From here on the LSTM is trained on the CNN features instead of raw frames;
# the labels are unchanged.
X_Train_Total = output

print('X_Train_Total Shape: {}'.format(X_Train_Total.shape))
print('Y_Train_Total Shape: {}'.format(Y_Train_Total.shape))

# Hold out 20% of the samples (first axis) for validation, seeded with SEED.
_lstm_split = train_test_split(
    X_Train_Total,
    Y_Train_Total,
    test_size=0.2,
    random_state=SEED,
)
X_Train_LSTM, X_Val_LSTM, Y_Train_LSTM, Y_Val_LSTM = _lstm_split

from keras.layers import Reshape, LSTM, Lambda, TimeDistributed, Conv1D, MaxPool1D, Dense, Dropout, Flatten, Conv2D, MaxPool2D, BatchNormalization, AveragePooling2D, GlobalAveragePooling2D
139
140
141
def build_lstm():
    """Build and compile the LSTM that scores every frame of a sequence.

    The network takes inputs of shape ``(NUM_FRAMES, 64)`` and stacks two
    1000-unit LSTM layers (each followed by 50% dropout) and one 4-unit LSTM
    layer. All of them return full sequences, so the final 4-way softmax
    ``Dense`` layer (named ``'root'``) yields one prediction per frame.

    Returns:
        The compiled Keras model (categorical cross-entropy loss, Adam
        optimizer, accuracy metric). The model summary is printed as a side
        effect.
    """
    # Imported here because the file never imports these names at module level.
    from keras.layers import Input
    from keras.models import Model
    from keras.optimizers import Adam

    inputs = Input(shape=(NUM_FRAMES, 64))

    x = LSTM(1000, return_sequences=True)(inputs)
    x = Dropout(0.5)(x)

    x = LSTM(1000, return_sequences=True)(x)
    x = Dropout(0.5)(x)

    x = LSTM(4, return_sequences=True)(x)
    # multi output
    outputs = Dense(4, activation='softmax', name='root')(x)

    # model
    model = Model(inputs, outputs)

    # `learning_rate` replaces the deprecated `lr` alias. The former
    # `decay=0.0` argument meant "no decay", which is already the default, and
    # newer Keras optimizers reject the keyword, so it is omitted.
    optimizer = Adam(learning_rate=0.002, beta_1=0.9, beta_2=0.999, epsilon=0.1)
    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    model.summary()

    return model


# These callbacks are not imported anywhere else in this file.
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

model_lstm = build_lstm()

# Halve the learning rate when validation accuracy has not improved for
# 5 epochs, never going below 1e-3.
annealer = ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=5, verbose=1, min_lr=1e-3)
# Write 'model_lstm.h5' only when the monitored quantity improves (the
# callback's default monitor is used, since none is passed).
checkpoint = ModelCheckpoint('model_lstm.h5', verbose=1, save_best_only=True)

# Fit on the in-memory arrays; no data augmentation is applied at this stage.
hist = model_lstm.fit(X_Train_LSTM, Y_Train_LSTM,
                      batch_size=BATCH_SIZE_LSTM,
                      epochs=EPOCHS_LSTM,
                      verbose=2,
                      callbacks=[annealer, checkpoint],
                      validation_data=(X_Val_LSTM, Y_Val_LSTM))

# Loss and accuracy of the model (as it stands after the last epoch) on the
# validation split.
final_loss_lstm, final_accuracy_lstm = model_lstm.evaluate(X_Val_LSTM, Y_Val_LSTM)
print('Final Loss LSTM: {}, Final Accuracy LSTM: {}'.format(final_loss_lstm, final_accuracy_lstm))

score_types = ['Score 0', 'Score 1', 'Score 2', 'Score 3']

# Flatten the 3-D (sample, frame, class) arrays to one row per frame, then
# take the index of the largest class value as the class of that frame.
Y_pred_lstm = model_lstm.predict(X_Val_LSTM)
Y_pred_lstm = np.reshape(Y_pred_lstm, (-1, Y_pred_lstm.shape[2]))
Y_pred_lstm = np.argmax(Y_pred_lstm, axis=1)

Y_true_lstm = np.reshape(Y_Val_LSTM, (-1, Y_Val_LSTM.shape[2]))
Y_true_lstm = np.argmax(Y_true_lstm, axis=1)

# Pin the label set so the matrix is always len(score_types) x len(score_types):
# without `labels`, a class missing from both arrays would shrink the matrix
# and it would no longer line up with the tick labels below.
cm = confusion_matrix(Y_true_lstm, Y_pred_lstm, labels=list(range(len(score_types))))
plt.figure(figsize=(12, 12))
ax = sns.heatmap(cm, cmap=plt.cm.Greens, annot=True, square=True, xticklabels=score_types, yticklabels=score_types)
ax.set_ylabel('Actual', fontsize=40)
ax.set_xlabel('Predicted', fontsize=40)

'''
# accuracy plot 
plt.plot(hist.history['accuracy'])
plt.plot(hist.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

# loss plot
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
'''