---
jupyter:
  jupytext:
    formats: ipynb,md
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.3'
      jupytext_version: 1.14.4
  kernelspec:
    display_name: Python 3 (ipykernel)
    language: python
    name: python3
---

<!-- #region -->
# 3D Image Classification

Learn how to train a 3D convolutional neural network (3D CNN) to predict the presence of viral pneumonia, based on the [Tutorial on 3D Image Classification](https://keras.io/examples/vision/3D_image_classification/) by [Hasib Zunair](https://github.com/hasibzunair).

> __Dataset__: [MosMedData: Chest CT Scans with COVID-19 Related Findings Dataset](https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1) :: This dataset contains anonymised human lung computed tomography (CT) scans with COVID-19 related findings, as well as scans without such findings. A small subset of studies has been annotated with binary pixel masks depicting regions of interest (ground-glass opacifications and consolidations). The CT scans were obtained between 1 March 2020 and 25 April 2020 and were provided by municipal hospitals in Moscow, Russia.
<!-- #endregion -->

## Verify GPU Support

```python
# import TensorFlow and check which GPU device is active
import tensorflow as tf

device_name = tf.test.gpu_device_name()
print('Active GPU :: {}'.format(device_name))
# Active GPU :: /device:GPU:0
```
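
`tf.test.gpu_device_name()` works, but it is part of the legacy test utilities. A quick alternative (not in the original notebook) using the public device API, assuming TensorFlow 2.x:

```python
# alternative check via the public device-listing API
gpus = tf.config.list_physical_devices('GPU')
print('Visible GPUs :: {}'.format(gpus))
# e.g. [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
```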

## Import Dependencies

```python
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import os
import random
from scipy import ndimage
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model
```

```python
# helper functions
from helper import (read_scan,
                    normalize,
                    resize_volume,
                    process_scan,
                    rotate,
                    train_preprocessing,
                    validation_preprocessing,
                    plot_slices,
                    build_model)
```
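
The helpers live in the `helper` module and are not shown in this README. A minimal sketch of the scan-loading pipeline, assuming it follows the Keras tutorial (read the NIfTI file, clip Hounsfield units to [-1000, 400], rescale to [0, 1], and zoom each volume to 128x128x64):

```python
# sketch of the scan-processing helpers (assumed to mirror the Keras tutorial)
import nibabel as nib
import numpy as np
from scipy import ndimage


def read_scan(filepath):
    """Load a NIfTI CT scan and return its voxel data as a numpy array."""
    return nib.load(filepath).get_fdata()


def normalize(volume):
    """Clip Hounsfield units to [-1000, 400] and rescale to [0, 1]."""
    min_hu, max_hu = -1000, 400
    volume = np.clip(volume, min_hu, max_hu)
    return ((volume - min_hu) / (max_hu - min_hu)).astype('float32')


def resize_volume(volume, width=128, height=128, depth=64):
    """Zoom the volume to the target width, height and depth."""
    factors = (width / volume.shape[0],
               height / volume.shape[1],
               depth / volume.shape[2])
    return ndimage.zoom(volume, factors, order=1)


def process_scan(path):
    """Read, normalize and resize a single CT scan."""
    return resize_volume(normalize(read_scan(path)))
```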

## Import Dataset

```python
# download from https://github.com/hasibzunair/3D-image-classification-tutorial/releases/
data_dir = './dataset'
no_pneumonia = os.path.join(data_dir, 'no_viral_pneumonia')
with_pneumonia = os.path.join(data_dir, 'with_viral_pneumonia')

normal_scan_paths = [
    os.path.join(no_pneumonia, i)
    for i in os.listdir(no_pneumonia)
]
print('INFO :: CT Scans with normal lung tissue:', len(normal_scan_paths))

abnormal_scan_paths = [
    os.path.join(with_pneumonia, i)
    for i in os.listdir(with_pneumonia)
]
print('INFO :: CT Scans with abnormal lung tissue:', len(abnormal_scan_paths))

# INFO :: CT Scans with normal lung tissue: 100
# INFO :: CT Scans with abnormal lung tissue: 100
```

## Visualize Dataset

```python
img_normal = nib.load(normal_scan_paths[0])
img_normal_array = img_normal.get_fdata()

img_abnormal = nib.load(abnormal_scan_paths[0])
img_abnormal_array = img_abnormal.get_fdata()

plt.figure(figsize=(30, 10))

for i in range(6):
    plt.subplot(2, 6, i+1)
    plt.imshow(img_normal_array[:, :, i], cmap='Blues')
    plt.axis('off')
    plt.title('Slice {} - Normal'.format(i))

    plt.subplot(2, 6, 6+i+1)
    plt.imshow(img_abnormal_array[:, :, i], cmap='Reds')
    plt.axis('off')
    plt.title('Slice {} - Abnormal'.format(i))
```

![3D Image Classification](https://github.com/mpolinowski/deep-3d-image-segmentation/blob/master/assets/3D_Image_Classification_01.png)

## Data Pre-processing

### Normalization

```python
# Read and process the scans.
# Each scan is resized across height, width, and depth and rescaled.
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])
```

```python
# Assign label 1 to CT scans with viral pneumonia and label 0 to normal scans.
abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
normal_labels = np.array([0 for _ in range(len(normal_scans))])
```

### Train Test Split

```python
X = np.concatenate((abnormal_scans, normal_scans), axis=0)
Y = np.concatenate((abnormal_labels, normal_labels), axis=0)

x_train, x_val, y_train, y_val = train_test_split(X, Y, test_size=0.3, random_state=42)
print('INFO :: Train / Test Samples - %d / %d' % (x_train.shape[0], x_val.shape[0]))
# INFO :: Train / Test Samples - 140 / 60
```

### Data Augmentation

#### Data Loader

```python
# Define data loaders.
train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))
batch_size = 2

# Augment the training data on the fly.
train_dataset = (
    train_loader.shuffle(len(x_train))
    .map(train_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)
# Do not augment the validation data.
validation_dataset = (
    validation_loader.shuffle(len(x_val))
    .map(validation_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)
```
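
`train_preprocessing`, `validation_preprocessing` and `rotate` are also imported from the `helper` module. A sketch of what they plausibly do, assuming the augmentation strategy of the Keras tutorial (random rotation of each volume plus adding a channel axis):

```python
# sketch of the preprocessing helpers (assumed to mirror the Keras tutorial)
import random
import tensorflow as tf
from scipy import ndimage


def rotate(volume):
    """Rotate the volume by a small random angle around the depth axis."""
    def scipy_rotate(vol):
        angle = random.choice([-20, -10, -5, 5, 10, 20])
        vol = ndimage.rotate(vol, angle, reshape=False)
        return vol.clip(0, 1).astype('float32')
    return tf.numpy_function(scipy_rotate, [volume], tf.float32)


def train_preprocessing(volume, label):
    """Augment on the fly and add a channel dimension."""
    volume = rotate(volume)
    volume = tf.expand_dims(volume, axis=3)
    return volume, label


def validation_preprocessing(volume, label):
    """Only add the channel dimension - no augmentation."""
    volume = tf.expand_dims(volume, axis=3)
    return volume, label
```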

### Visualizing Augmented Datasets

```python
data = train_dataset.take(1)
images, labels = list(data)[0]
images = images.numpy()
image = images[0]
print("CT Scan Dims:", image.shape)
# CT Scan Dims: (128, 128, 64, 1)
plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")

# Visualize a montage of slices:
# 4 rows and 10 columns for 40 slices of the CT scan.
plot_slices(4, 10, 128, 128, image[:, :, :40])
```

![3D Image Classification](https://github.com/mpolinowski/deep-3d-image-segmentation/blob/master/assets/3D_Image_Classification_02.png)

![3D Image Classification](https://github.com/mpolinowski/deep-3d-image-segmentation/blob/master/assets/3D_Image_Classification_03.png)
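
`plot_slices` produces the montage above. A minimal sketch, assuming it simply tiles depth slices into a rows-by-columns grid of subplots (the `width`/`height` arguments are only used to scale the figure here):

```python
# sketch of a slice-montage helper (assumed behaviour of helper.plot_slices)
import matplotlib.pyplot as plt
import numpy as np


def plot_slices(num_rows, num_columns, width, height, data):
    """Plot num_rows x num_columns depth slices of a (width, height, depth[, 1]) volume."""
    fig, axes = plt.subplots(num_rows, num_columns,
                             figsize=(num_columns * width / 100, num_rows * height / 100))
    for index, ax in enumerate(axes.ravel()):
        ax.imshow(np.squeeze(data[:, :, index]), cmap='gray')
        ax.axis('off')
    plt.show()
```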

## Building the Model

```python
model = build_model(width=128, height=128, depth=64)
model.summary()
```

<!-- #region -->
```bash
Model: "3dctscan"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_3 (InputLayer)        [(None, 128, 128, 64, 1)  0
                             ]

 conv3d_9 (Conv3D)           (None, 126, 126, 62, 64)  1792

 max_pooling3d_8 (MaxPooling  (None, 63, 63, 31, 64)   0
 3D)

 batch_normalization_8 (Batc  (None, 63, 63, 31, 64)   256
 hNormalization)

 conv3d_10 (Conv3D)          (None, 61, 61, 29, 64)    110656

 max_pooling3d_9 (MaxPooling  (None, 30, 30, 14, 64)   0
 3D)

 batch_normalization_9 (Batc  (None, 30, 30, 14, 64)   256
 hNormalization)

 conv3d_11 (Conv3D)          (None, 28, 28, 12, 128)   221312

 max_pooling3d_10 (MaxPoolin  (None, 14, 14, 6, 128)   0
 g3D)

 batch_normalization_10 (Bat  (None, 14, 14, 6, 128)   512
 chNormalization)

 conv3d_12 (Conv3D)          (None, 12, 12, 4, 256)    884992

 max_pooling3d_11 (MaxPoolin  (None, 6, 6, 2, 256)     0
 g3D)

 batch_normalization_11 (Bat  (None, 6, 6, 2, 256)     1024
 chNormalization)

 global_average_pooling3d_1   (None, 256)              0
 (GlobalAveragePooling3D)

 dense_2 (Dense)             (None, 512)               131584

 dropout_1 (Dropout)         (None, 512)               0

 dense_3 (Dense)             (None, 1)                 513

=================================================================
Total params: 1,352,897
Trainable params: 1,351,873
Non-trainable params: 1,024
_________________________________________________________________
```
<!-- #endregion -->
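
`build_model` comes from the `helper` module. The summary above pins down the architecture, so a sketch consistent with it (four Conv3D + MaxPool3D + BatchNormalization stages with 64/64/128/256 filters, global average pooling, a 512-unit dense layer with dropout, and a sigmoid output; the dropout rate is not visible in the summary, 0.3 is the value used in the Keras tutorial):

```python
# sketch of build_model, reconstructed from the model summary above
from tensorflow import keras
from tensorflow.keras import layers


def build_model(width=128, height=128, depth=64):
    """3D CNN for binary classification of CT volumes."""
    inputs = keras.Input((width, height, depth, 1))

    x = inputs
    for filters in (64, 64, 128, 256):
        x = layers.Conv3D(filters=filters, kernel_size=3, activation='relu')(x)
        x = layers.MaxPool3D(pool_size=2)(x)
        x = layers.BatchNormalization()(x)

    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(units=512, activation='relu')(x)
    x = layers.Dropout(0.3)(x)

    outputs = layers.Dense(units=1, activation='sigmoid')(x)
    return keras.Model(inputs, outputs, name='3dctscan')
```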

### Compile the Model

```python
initial_learning_rate = 0.0001

lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True)

model.compile(
    loss='binary_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    metrics=['acc']
)
```
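
With `staircase=True` the schedule lowers the learning rate in discrete steps rather than continuously. For reference, this mirrors the documented formula (it is not extra code in the notebook):

```python
# learning rate used at a given optimizer step (staircase=True floors the exponent)
def decayed_learning_rate(step, initial_learning_rate=0.0001,
                          decay_rate=0.96, decay_steps=100000):
    return initial_learning_rate * decay_rate ** (step // decay_steps)
```

At roughly 70 training steps per epoch, a 100-epoch run never reaches the first 100,000-step decay boundary, so the effective learning rate stays at `1e-4` here; the schedule only matters for much longer runs.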

### Callbacks

```python
cp_cb = keras.callbacks.ModelCheckpoint(
    './checkpoints/3dct_weights.h5',
    save_best_only=True
)

es_cb = keras.callbacks.EarlyStopping(
    monitor='val_acc',
    patience=15
)
```
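
Note that `ModelCheckpoint` monitors `val_loss` by default, while the early-stopping callback watches `val_acc`. If you want both callbacks to track the same metric, the monitor can be set explicitly (a variation on the setup above, not part of the original notebook):

```python
# variant: make the checkpoint track validation accuracy like the early-stopping callback
cp_cb = keras.callbacks.ModelCheckpoint(
    './checkpoints/3dct_weights.h5',
    monitor='val_acc',
    mode='max',
    save_best_only=True
)
```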

## Model Training

```python
epochs = 100

model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    shuffle=True,
    verbose=2,
    callbacks=[cp_cb, es_cb]
)
# Epoch 46/100
# 70/70 - 22s - loss: 0.3383 - acc: 0.8429 - val_loss: 0.8225 - val_acc: 0.6833 - 22s/epoch - 313ms/step
```

### Visualizing Model Performance

```python
fig, ax = plt.subplots(1, 2, figsize=(20, 3))
ax = ax.ravel()

for i, metric in enumerate(["acc", "loss"]):
    ax[i].plot(model.history.history[metric])
    ax[i].plot(model.history.history["val_" + metric])
    ax[i].set_title("Model {}".format(metric))
    ax[i].set_xlabel("epochs")
    ax[i].set_ylabel(metric)
    ax[i].legend(["train", "val"])
```

![3D Image Classification](https://github.com/mpolinowski/deep-3d-image-segmentation/blob/master/assets/3D_Image_Classification_04.png)

## Loading Best Training Weights

```python
model.load_weights('./checkpoints/3dct_weights.h5')
```

## Make Predictions

```python
pred_dataset = './predictions'
pred_paths = [os.path.join(pred_dataset, i) for i in os.listdir(pred_dataset)]

z_val = np.array([process_scan(path) for path in pred_paths])
class_names = ['normal', 'abnormal']

# predict and display every scan in the predictions folder
for i in range(len(z_val)):
    prediction = model.predict(np.expand_dims(z_val[i], axis=0))[0]
    scores = [1 - prediction[0], prediction[0]]

    pred_image = nib.load(pred_paths[i])
    pred_image_data = pred_image.get_fdata()

    normal_class = class_names[0], round(100*scores[0], 2)
    abnormal_class = class_names[1], round(100*scores[1], 2)
    annotation = normal_class + abnormal_class

    plt.imshow(pred_image_data[:, :, pred_image_data.shape[2]//2], cmap='gray')
    plt.title(os.path.basename(pred_paths[i]))
    plt.xlabel(annotation)
    plt.show()
```

![3D Image Classification](https://github.com/mpolinowski/deep-3d-image-segmentation/blob/master/assets/3D_Image_Classification_05.png)

```python
print(scores, class_names)
```