# Volumetric Brain Tumor Segmentation

This repository experiments with techniques to improve dense, volumetric semantic segmentation. Specifically, the model follows the U-net architectural style and includes a variational autoencoder (for regularization), residual blocks, spatial and channel squeeze-excitation layers, and dense connections.

## Model

This is a variation of the [U-net](https://arxiv.org/pdf/1606.06650.pdf) architecture with [variational autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf). There are several architectural enhancements, including:

- [Spatial and channel squeeze-excitation layers](https://arxiv.org/abs/1803.02579) in the ResNet blocks (a sketch follows this list).
- [Dense connections](https://arxiv.org/pdf/1608.06993.pdf) between encoder ResNet blocks at the same spatial resolution level.
- Convolutional blocks ordered `[Conv3D, GroupNorm, ReLU]`, except for pointwise and output layers.
- He normal initialization for *all* layer kernels except those with sigmoid activations, which are initialized with Glorot normal.
- Convolutional downsampling and upsampling operations.
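
A minimal sketch of the squeeze-excitation idea, assuming channels-last 3-D feature maps; the reduction ratio and the additive combination of the two branches are illustrative choices, not necessarily those used in `model/model.py`:

```
import tensorflow as tf

class ChannelSpatialSE3D(tf.keras.layers.Layer):
    """Concurrent spatial and channel squeeze-excitation (scSE) for 3-D feature maps."""

    def __init__(self, channels, reduction=2, **kwargs):
        super().__init__(**kwargs)
        # Channel SE: squeeze the spatial dims, excite channels via a bottleneck MLP.
        self.pool = tf.keras.layers.GlobalAveragePooling3D()
        self.fc1 = tf.keras.layers.Dense(channels // reduction, activation='relu')
        self.fc2 = tf.keras.layers.Dense(channels, activation='sigmoid')
        # Spatial SE: squeeze the channels with a pointwise conv, excite each voxel.
        self.spatial_gate = tf.keras.layers.Conv3D(1, 1, activation='sigmoid')

    def call(self, x):
        # cSE branch: per-channel gates in [0, 1], broadcast over the volume.
        w = self.fc2(self.fc1(self.pool(x)))  # shape (batch, channels)
        cse = x * w[:, tf.newaxis, tf.newaxis, tf.newaxis, :]
        # sSE branch: per-voxel gates in [0, 1], broadcast over channels.
        sse = x * self.spatial_gate(x)
        return cse + sse
```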

## Usage

Dependencies are supported for Python 3 only and can be found in `requirements.txt` (`numpy==1.15` for preprocessing and `tensorflow==2.0.0-alpha0` for the model architecture, which uses `tf.keras.Model` and `tf.keras.layers.Layer` subclassing).

The model can be found in `model/model.py` and supports an `inference` mode in addition to the `training` mode that `tf.keras.Model` supports (see the sketch after this list).

- Specify `training=False, inference=True` to receive only the decoder output, as desired at test time.
- Specify `training=False, inference=False` to receive both the decoder and variational autoencoder outputs so that loss and metrics can be computed, as desired at validation time.
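
A minimal sketch of the two evaluation call modes, assuming `model` is an instance of the class defined in `model/model.py`; the tuple return structure and channels-last input layout are assumptions, so check the actual code:

```
import tensorflow as tf

# Illustrative 2-channel (t1ce, flair) crop of the training input size.
x = tf.zeros([1, 128, 128, 128, 2])

# Validation: both heads are returned so loss and metrics can be computed.
y_pred, y_vae = model(x, training=False, inference=False)

# Test: only the decoder output is returned.
y_pred = model(x, training=False, inference=True)
```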

### BraTS Data

The BraTS 2017/2018 dataset is not publicly available, so download scripts are not provided. Once the data is downloaded, run preprocessing on the original format, which should look something like this:

```
BraTS17TrainingData/*/*/*[t1,t1ce,t2,flair,seg].nii.gz
```

### Preprocessing

For each example, there are 4 modalities and 1 label, each of shape `240 x 240 x 155`. Preprocessing consists of the following steps (a normalization sketch follows the list):

- Concatenate the `t1ce` and `flair` modalities along the channel dimension.
- Compute per-channel, image-wise `mean` and `std` and normalize per channel *for the training set*.
- Crop as much background as possible across all images. Final image sizes are `155 x 190 x 147`.
- Serialize to TFRecord format for convenience in training.
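
A rough sketch of the per-channel normalization for a single image; variable names are illustrative, and the assumption here is that statistics of this kind are what `prepro.npy` carries forward to validation and test time:

```
import numpy as np

# img: one example of shape (H, W, D, C), with C = 2 after concatenating
# the t1ce and flair modalities along the channel dimension.
img = img.astype(np.float32)
mean = img.mean(axis=(0, 1, 2), keepdims=True)  # per-channel mean over the volume
std = img.std(axis=(0, 1, 2), keepdims=True)    # per-channel std over the volume
img = (img - mean) / (std + 1e-8)               # epsilon guards against flat channels
```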

```
python preprocess.py \
    --in_locs /path/to/BraTS17TrainingData \
    --modalities t1ce,flair \
    --truth seg \
    --create_val
```

> All command-line arguments can be found in `args.py`.

> There are 285 examples in the BraTS 2017/2018 training sets, but for lack of a validation set, the `--create_val` flag creates a 10:1 split, resulting in 260 training and 25 validation examples.

### Training

Most hyperparameters proposed in the paper are used in training. Each input is randomly flipped across the spatial axes with probability 0.5 and randomly cropped to `128 x 128 x 128` during training, making the training data stochastic (see the sketch below). The validation set is dynamically created each epoch in a similar fashion.
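
A minimal eager-mode sketch of this augmentation; the function and argument names are illustrative, not taken from `train.py`:

```
import tensorflow as tf

def random_flip_and_crop(image, label, crop_size=(128, 128, 128)):
    # image: (H, W, D, C) volume; label: (H, W, D, 1) segmentation mask.
    # Flip each spatial axis independently with probability 0.5.
    for axis in range(3):
        if tf.random.uniform([]) < 0.5:
            image = tf.reverse(image, axis=[axis])
            label = tf.reverse(label, axis=[axis])
    # Crop image and label together so both get the same random 128^3 window.
    stacked = tf.concat([image, tf.cast(label, image.dtype)], axis=-1)
    cropped = tf.image.random_crop(stacked, size=list(crop_size) + [stacked.shape[-1]])
    n_ch = image.shape[-1]
    return cropped[..., :n_ch], cropped[..., n_ch:]
```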

```
python train.py \
    --train_loc /path/to/train \
    --val_loc /path/to/val \
    --prepro_file /path/to/prepro/prepro.npy \
    --save_folder checkpoint \
    --crop_size 128,128,128
```

> Use the `--gpu` flag to run on GPU.

### Testing: Generating Segmentation Masks

The testing script `test.py` runs inference on unlabeled input data, generating segmentation masks over the whole image, which is padded to a size compatible with the model's downsampling (a padding sketch follows). The VAE branch is not run at inference time, so the model is effectively fully convolutional.
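
The padding step amounts to rounding each spatial dimension up to a multiple of the model's total downsampling factor. A sketch, assuming four 2x downsampling stages (factor 16; check the actual encoder depth in `model/model.py`):

```
import numpy as np

def pad_to_multiple(volume, multiple=16):
    # volume: (H, W, D, C). Pad each spatial dim up to the next multiple
    # of `multiple` so repeated 2x downsampling divides evenly.
    pads = [(-s) % multiple for s in volume.shape[:3]]
    pad_width = [(0, p) for p in pads] + [(0, 0)]  # pad at the end; channels untouched
    return np.pad(volume, pad_width, mode='constant')
```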

```
python test.py \
    --in_locs /path/to/test \
    --modalities t1ce,flair \
    --prepro_loc /path/to/prepro/prepro.npy \
    --tumor_model checkpoint
```

*Training arguments are saved in the checkpoint folder. This bypasses the need for manual model initialization.*

> The `Interpolator` class rescales voxel sizes so that all inputs can be resampled to 1 mm^3, as sketched below.
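
Conceptually, resampling to 1 mm^3 scales each spatial axis by its physical voxel size. A sketch for illustration only, using `scipy.ndimage` (scipy is not a listed dependency, and this is not the `Interpolator` API):

```
from scipy.ndimage import zoom

def resample_to_isotropic(volume, spacing):
    # volume: 3-D array for one modality; spacing: voxel size (sx, sy, sz) in mm.
    # A voxel of size s mm along an axis maps to s output voxels of 1 mm.
    return zoom(volume, zoom=spacing, order=1)  # order=1: trilinear interpolation
```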

> NOTE: `test.py` is not fully debugged and functional. If needed, please open an issue.

### Skull Stripping

Because BraTS contains skull-stripped images, which are uncommon in real applications, we support training and inference of skull-stripping models. The same pipeline generalizes, using the NFBS skull-stripping dataset available [here](http://preprocessed-connectomes-project.org/NFB_skullstripped/). Note that in model initialization and training, the number of output channels `--out_ch` differs between the two tasks.

> If the testing data contains bits of skull, run skull stripping and tumor segmentation sequentially at inference time by specifying the `--skull_model` flag (a hypothetical invocation follows). All preprocessing and training should work for both tasks as is.
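
For example, a hypothetical combined invocation, assuming `--skull_model` takes a checkpoint folder analogous to `--tumor_model` (`skull_checkpoint` is a placeholder for a trained skull-stripping checkpoint):

```
python test.py \
    --in_locs /path/to/test \
    --modalities t1ce,flair \
    --prepro_loc /path/to/prepro/prepro.npy \
    --tumor_model checkpoint \
    --skull_model skull_checkpoint
```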

### Results

We run training on a V100 32GB GPU with a batch size of 1. Each epoch takes around 12 minutes. Below is a sample training curve, using all default model parameters.

|Epoch|Training Loss|Training Dice Score|Validation Loss|Validation Dice Score|
|:---:|:-----------:|:-----------------:|:-------------:|:-------------------:|
|0    |1.000        |0.134              |0.732          |0.248                |
|50   |0.433        |0.598              |0.413          |0.580                |
|100  |0.386        |0.651              |0.421          |0.575                |
|150  |0.356        |0.676              |0.393          |0.594                |
|200  |0.324        |0.692              |0.349          |0.642                |
|250  |0.295        |0.716              |0.361          |0.630                |
|300  |0.282        |0.729              |0.352          |0.644                |