# Volumetric Brain Tumor Segmentation
This repository experiments with techniques to improve dense, volumetric semantic segmentation. Specifically, the model follows the U-net architectural style and adds a variational autoencoder (for regularization), residual blocks, spatial and channel squeeze-excitation layers, and dense connections.

## Model
This is a variation of the [U-net](https://arxiv.org/pdf/1606.06650.pdf) architecture with [variational autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf). There are several architectural enhancements, including:
 - [Spatial and channel squeeze-excitation layers](https://arxiv.org/abs/1803.02579) in the ResNet blocks.
 - [Dense connections](https://arxiv.org/pdf/1608.06993.pdf) between encoder ResNet blocks at the same spatial resolution level.
 - Convolutional layers ordered as `[Conv3D, GroupNorm, ReLU]`, except for all pointwise and output layers.
 - He normal initialization for *all* layer kernels except those with sigmoid activations, which are initialized with Glorot normal.
 - Convolutional downsampling and upsampling operations.
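As an illustration of the squeeze-excitation idea above, here is a minimal NumPy sketch of concurrent spatial and channel SE on a single `(D, H, W, C)` volume. This is not the repository's implementation; the function names, weight shapes, and the element-wise `max` combination are assumptions for illustration.

```python
import numpy as np

def channel_se(x, w1, b1, w2, b2):
    """Channel SE: squeeze spatial dims, excite channels with a gated MLP."""
    z = x.mean(axis=(0, 1, 2))                    # (C,) global average pool
    h = np.maximum(z @ w1 + b1, 0.0)              # (C // r,) bottleneck + ReLU
    s = 1.0 / (1.0 + np.exp(-(h @ w2 + b2)))      # (C,) sigmoid gate
    return x * s                                  # recalibrate each channel

def spatial_se(x, w, b):
    """Spatial SE: pointwise projection to one channel, sigmoid gate."""
    q = 1.0 / (1.0 + np.exp(-(x @ w + b)))        # (D, H, W, 1)
    return x * q                                  # recalibrate each voxel

def sc_se(x, w1, b1, w2, b2, w, b):
    """Concurrent scSE: combine the two recalibrations element-wise."""
    return np.maximum(channel_se(x, w1, b1, w2, b2), spatial_se(x, w, b))
```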
## Usage
Dependencies are supported only for Python 3 and can be found in `requirements.txt` (`numpy==1.15` for preprocessing and `tensorflow==2.0.0-alpha0` for the model architecture, which uses `tf.keras.Model` and `tf.keras.layers.Layer` subclassing).

The model can be found in `model/model.py` and supports an `inference` mode in addition to the `training` mode that `tf.keras.Model` provides.
 - Specify `training=False, inference=True` to receive only the decoder output, as desired at test time.
 - Specify `training=False, inference=False` to receive both the decoder and variational autoencoder outputs so that loss and metrics can be computed, as desired at validation time.
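Schematically, the three modes behave like the sketch below. The helper outputs are hypothetical stand-ins; the real forward passes live in `model/model.py`.

```python
def model_call(x, training=False, inference=False):
    # Stand-ins for the real decoder and VAE forward passes.
    decoder_out = f"decoder({x})"
    vae_out = f"vae({x})"
    if training:
        return decoder_out, vae_out   # training: both heads
    if inference:
        return decoder_out            # test time: decoder output only
    return decoder_out, vae_out      # validation: both heads for loss/metrics
```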
### BraTS Data
The BraTS 2017/2018 dataset is not publicly available, so no download scripts are provided. Once the data is downloaded, run preprocessing on the original format, which should look something like this:
```
BraTS17TrainingData/*/*/*[t1,t1ce,t2,flair,seg].nii.gz
```
### Preprocessing
For each example, there are 4 modalities and 1 label, each of shape `240 x 240 x 155`. Preprocessing steps consist of:
 - Concatenate the `t1ce` and `flair` modalities along the channel dimension.
 - Compute per-channel image-wise `mean` and `std` and normalize per channel *for the training set*.
 - Crop as much background as possible across all images. Final image sizes are `155 x 190 x 147`.
 - Serialize to `tf.TFRecord` format for convenience in training.
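The concatenation, normalization, and cropping steps can be sketched in NumPy as follows. This is illustrative only; the function names are not the repo's API.

```python
import numpy as np

def to_example(t1ce, flair):
    # stack the two modalities along a new channel dimension
    return np.stack([t1ce, flair], axis=-1)

def normalize(x, mean, std):
    # per-channel standardization with *training-set* statistics
    return (x - mean) / (std + 1e-8)

def crop_background(x):
    # tight bounding box of foreground (nonzero) voxels across channels
    mask = (x != 0).any(axis=-1)
    lo = np.argwhere(mask).min(axis=0)
    hi = np.argwhere(mask).max(axis=0) + 1
    return x[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
```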
```
python preprocess.py \
    --in_locs /path/to/BraTS17TrainingData \
    --modalities t1ce,flair \
    --truth seg \
    --create_val
```
> All command-line arguments can be found in `args.py`.
> There are 285 training examples in the BraTS 2017/2018 training sets, but since no validation set is provided, the `--create_val` flag creates a roughly 10:1 split, resulting in 260 training and 25 validation examples.
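The split could look roughly like this sketch, under the assumption of a shuffled 10:1 split; the actual `--create_val` logic may differ.

```python
import numpy as np

def train_val_split(examples, seed=0):
    # ~10:1 split: 285 BraTS examples -> 260 train / 25 val
    rng = np.random.RandomState(seed)
    idx = rng.permutation(len(examples))
    n_val = len(examples) // 11
    val = [examples[i] for i in idx[:n_val]]
    train = [examples[i] for i in idx[n_val:]]
    return train, val
```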
### Training
Most hyperparameters proposed in the paper are used in training. During training, the input is randomly flipped across each spatial axis with probability 0.5 and randomly cropped to `128 x 128 x 128` per example (making the training data stochastic). The validation set is dynamically created each epoch in a similar fashion.
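The flip-and-crop augmentation can be sketched in NumPy as below. This is illustrative, not the repository's exact implementation.

```python
import numpy as np

def augment(x, y, crop=(128, 128, 128), rng=None):
    # flip each spatial axis with probability 0.5, applied
    # identically to the image x and its label y
    rng = rng or np.random.RandomState()
    for ax in range(3):
        if rng.rand() < 0.5:
            x, y = np.flip(x, axis=ax), np.flip(y, axis=ax)
    # random spatial crop of size `crop`
    starts = [rng.randint(0, x.shape[i] - crop[i] + 1) for i in range(3)]
    sl = tuple(slice(s, s + c) for s, c in zip(starts, crop))
    return x[sl], y[sl]
```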
```
python train.py \
    --train_loc /path/to/train \
    --val_loc /path/to/val \
    --prepro_file /path/to/prepro/prepro.npy \
    --save_folder checkpoint \
    --crop_size 128,128,128
```
> Use the `--gpu` flag to run on GPU.
### Testing: Generating Segmentation Masks
The testing script `test.py` runs inference on unlabeled input data, generating sample labels over the whole image, which is padded to a size compatible with downsampling. The VAE branch is not run during inference, so the model is fully convolutional.
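Padding to a downsampling-compatible size might look like this sketch, assuming four stride-2 downsamplings; the script's actual padding may differ.

```python
import numpy as np

def pad_for_downsampling(x, n_downsamples=4):
    # pad each spatial dim up to the next multiple of 2**n_downsamples
    # so every stride-2 downsampling divides the volume evenly
    m = 2 ** n_downsamples
    pads = []
    for d in x.shape[:3]:
        extra = (-d) % m
        pads.append((extra // 2, extra - extra // 2))
    pads.append((0, 0))  # leave the channel dimension untouched
    return np.pad(x, pads, mode='constant')
```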
```
python test.py \
    --in_locs /path/to/test \
    --modalities t1ce,flair \
    --prepro_loc /path/to/prepro/prepro.npy \
    --tumor_model checkpoint
```
*Training arguments are saved in the checkpoint folder. This bypasses the need for manual model initialization.*
> The `Interpolator` class interpolates voxel sizes during rescaling so that all inputs can be resampled to 1 mm^3.
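A nearest-neighbor sketch of isotropic resampling is below. The actual `Interpolator` presumably uses proper interpolation; this only illustrates the voxel-spacing arithmetic.

```python
import numpy as np

def resample_isotropic(x, spacing):
    # resample a (D, H, W, C) volume with per-axis voxel spacing (mm)
    # to 1 mm isotropic voxels via nearest-neighbor lookup
    new_dims = [int(round(d * s)) for d, s in zip(x.shape[:3], spacing)]
    idx = [np.minimum((np.arange(n) / s).astype(int), d - 1)
           for n, s, d in zip(new_dims, spacing, x.shape[:3])]
    return x[np.ix_(*idx)]
```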
> NOTE: `test.py` is not fully debugged. If you run into problems, please open an issue.
### Skull Stripping
Because BraTS contains skull-stripped images, which are uncommon in real applications, we also support training and inference of skull-stripping models. The same pipeline generalizes, using the NFBS skull-stripping dataset [here](http://preprocessed-connectomes-project.org/NFB_skullstripped/). Note that in model initialization and training, the number of output channels `--out_ch` differs between the two tasks.
> If the testing data has not been skull-stripped, run skull stripping and tumor segmentation sequentially at inference time by specifying the `--skull_model` flag. All preprocessing and training works for both tasks as is.
### Results
We run training on a V100 32GB GPU with a batch size of 1. Each epoch takes roughly 12 minutes. Below is a sample training curve, using all default model parameters.
|Epoch|Training Loss|Training Dice Score|Validation Loss|Validation Dice Score|
|:---:|:-----------:|:-----------------:|:-------------:|:-------------------:|
|0    |1.000        |0.134              |0.732          |0.248                |
|50   |0.433        |0.598              |0.413          |0.580                |
|100  |0.386        |0.651              |0.421          |0.575                |
|150  |0.356        |0.676              |0.393          |0.594                |
|200  |0.324        |0.692              |0.349          |0.642                |
|250  |0.295        |0.716              |0.361          |0.630                |
|300  |0.282        |0.729              |0.352          |0.644                |