#######################################################################################
### We are looking for a new PhD student to work on COVID-19 segmentation in CTs and CXRs!
### For the details, see https://cit-ai.net/PhD-Scholarships.html#6 or email me at alex.ter-sarkisov@city.ac.uk
#######################################################################################
## Update from 14/09/22: Published in Applied Intelligence, January 2022, Volume 52, pages 9664–9675

## Update from 21/11/21: To appear in Applied Intelligence

# COVID-CT-Mask-Net: Prediction of COVID-19 From CT Scans Using Regional Features

[Presentation](https://github.com/AlexTS1980/COVID-CT-Mask-Net/blob/master/presentations/COVID_19_Presentation_Kent.pdf) at the University of Kent, 23-11-2020:

<p align="center">
<img src="https://github.com/AlexTS1980/COVID-CT-Mask-Net/blob/master/plots/Kent_231120.png" width="800" height="400" align="center"/>
</p>

Papers on medRxiv:

[Lightweight Model For The Prediction of COVID-19 Through The Detection And Segmentation of Lesions in Chest CT Scans](https://www.medrxiv.org/content/10.1101/2020.10.30.20223586v2.full.pdf)

[Detection and Segmentation of Lesion Areas in Chest CT Scans For The Prediction of COVID-19](https://www.medrxiv.org/content/10.1101/2020.10.23.20218461v2.full.pdf)

[COVID-CT-Mask-Net: Prediction of COVID-19 From CT Scans Using Regional Features](https://www.medrxiv.org/content/10.1101/2020.10.11.20211052v2.full.pdf)

BibTeX citations (preprints):

```
@article{Ter-Sarkisov2020.10.11.20211052,
    author = {Ter-Sarkisov, Aram},
    title = {COVID-CT-Mask-Net: Prediction of COVID-19 from CT Scans Using Regional Features},
    year = {2020},
    doi = {10.1101/2020.10.11.20211052},
    publisher = {Cold Spring Harbor Laboratory Press},
    journal = {medRxiv}
}

@article{Ter-Sarkisov2020.10.30.20223586,
    author = {Ter-Sarkisov, Aram},
    title = {Lightweight Model For The Prediction of COVID-19 Through The Detection And Segmentation of Lesions in Chest CT Scans},
    year = {2020},
    doi = {10.1101/2020.10.30.20223586},
    publisher = {Cold Spring Harbor Laboratory Press},
    journal = {medRxiv}
}

@article{Ter-Sarkisov2020.10.23.20218461,
    author = {Ter-Sarkisov, Aram},
    title = {Detection and Segmentation of Lesion Areas in Chest CT Scans For The Prediction of COVID-19},
    year = {2020},
    doi = {10.1101/2020.10.23.20218461},
    publisher = {Cold Spring Harbor Laboratory Press},
    journal = {medRxiv}
}
```

BibTeX citation (journal publication):
```
@article{TerSarkisov2022,
    author = {Ter-Sarkisov, Aram},
    title = {COVID-CT-Mask-Net: Prediction of COVID-19 from CT Scans Using Regional Features},
    year = {2022},
    doi = {10.1007/s10489-021-02731-6},
    journal = {Applied Intelligence},
    volume = {52},
    pages = {9664--9675}
}
```

## Update 01/11/20
I re-implemented torchvision's segmentation interface locally; in the end it was easier to keep two different files for the RPN and RoI modules of the segmentation and classification tasks: `rpn_segmentation` and `roi_segmentation` vs `rpn` and `roi`. For the validation split in `test_split_segmentation.txt`, I get the following results for the two lightweight and the two best full models (ResNet50+FPN backbone):

| Model | AP@0.5 | AP@0.75 | mAP@[0.5:0.95:0.05] | Model size |
|:-:|:-:|:-:|:-:|:-:|
| **Lightweight model (truncated ResNet34+FPN)** | 59.88% | 45.06% | 44.76% | 11.45M |
| **Lightweight model (truncated ResNet18+FPN)** | 49.95% | 37.78% | 39.32% | 6.12M |
| **Full model (merged masks)** | 61.92% | 45.22% | 44.68% | 31.78M |
| **Full model (GGO + C masks)** | 50.20% | 41.98% | 38.71% | 31.78M |

The penultimate column is the mean over 10 IoU thresholds, the main metric on the MS COCO leaderboard.

For each script, two additional arguments were added: `backbone_name`, one of `resnet18, resnet34, resnet50`, and `truncation`, one of `0,1,2`. For `resnet50`, only the full (base torchvision) model is implemented, with 4 connections to the FPN. For `resnet18` and `resnet34`, `truncation=0` uses the full backbone, `truncation=1` deletes the last block, and `truncation=2` deletes the last two blocks. For these two backbones, only the last feature layer is connected to the FPN.
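
To illustrate the idea, here is a minimal sketch (not the repository's backbone code) using torchvision's stock ResNet, whose four residual stages are the modules `layer1`..`layer4`; truncation amounts to dropping trailing stages and feeding only the final feature map to the FPN:

```python
# Illustrative sketch only (not the repo's backbone code): truncating a
# torchvision ResNet by dropping its trailing residual stages.
import torch
import torchvision

def truncated_backbone(name="resnet18", truncation=1):
    resnet = getattr(torchvision.models, name)()
    stages = [resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
              resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4]
    if truncation > 0:
        stages = stages[:-truncation]  # truncation=1 drops layer4, =2 drops layer3+layer4
    return torch.nn.Sequential(*stages)

backbone = truncated_backbone("resnet18", truncation=1)
features = backbone(torch.randn(1, 3, 512, 512))
print(features.shape)  # torch.Size([1, 256, 32, 32]) -- the single map fed to the FPN
```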

To evaluate a model, run (e.g., for the lightweight model with the ResNet18+FPN backbone and the last block truncated):
```
python3 evaluation_mean_ap.py --backbone_name resnet18 --ckpt model.pth --mask_type merge --truncation 1 --rpn_nms_th 0.75 --roi_nms_th 0.75 --confidence_th 0.75 
```
To train the segmentation model from scratch to get the results above:
```
python3.5 train_segmentation.py --num_epochs 100 --mask_type merge --save_every 10 --backbone_name resnet18 --truncation 1 --device cuda
```
Results of the classification models derived from the segmentation models above (class sensitivity and overall accuracy):

| Model | Control | CP | COVID-19 | Overall accuracy |
|:-:|:-:|:-:|:-:|:-:|
| **Lightweight model (truncated ResNet34+FPN)** | 92.89% | 91.70% | 91.76% | 92.89% |
| **Lightweight model (truncated ResNet18+FPN)** | 96.98% | 91.63% | 91.35% | 93.95% |
| **Full model (merged masks)** | 97.74% | 96.69% | 92.68% | 96.33% |
| **Full model (both masks)** | 96.91% | 95.06% | 93.88% | 95.64% |

To train a lightweight classifier, you need to specify the backbone name and the truncation level; they must be the same as in the segmentation model from which the classifier is derived. You also need to set the size of the RoI batch (`roi_batch_size`), which is the input size of the classification module; the number of mask classes on which the segmentation model was trained (`num_class`, 2 for merged and 3 for separate masks); and the number of features in the classification module **S** (`s_features`). You need at least the pretrained weights from a segmentation model. **You cannot train the COVID-CT-Mask-Net classifier from scratch.**
```
python3 train_classifier.py --pretrained_segmentation_model segmentation_model.pth --backbone_name resnet34 --num_epochs 50 --save_every 10 --num_class 2 --truncation 1 --s_features 512 --roi_batch_size 128 --batch_size 8
```
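
For intuition about how `roi_batch_size` and `s_features` fit together, here is a purely hypothetical shape sketch (the real module **s2_new** lives in `/models/mask_net/`; the per-RoI feature layout below is an assumption for illustration only):

```python
# Hypothetical shape sketch of classification module S (NOT the repo's s2_new):
# a fixed-size batch of RoI predictions is reduced to one image-level logit vector.
import torch
import torch.nn as nn

roi_batch_size, s_features, num_img_classes = 128, 512, 3
roi_features = 5  # assumed per-RoI layout (e.g. box coordinates + confidence score)

s_module = nn.Sequential(
    nn.Flatten(start_dim=0),                            # 128 x 5 RoI batch -> one vector
    nn.Linear(roi_batch_size * roi_features, s_features),
    nn.ReLU(),
    nn.Linear(s_features, num_img_classes),             # Control / CP / COVID-19 logits
)
logits = s_module(torch.randn(roi_batch_size, roi_features))
# If fewer than roi_batch_size RoIs arrive, the first Linear layer fails --
# which is why the RoI batch size must be fixed (see Section 3).
```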
In this case the weights for all parameters except **S** are copied from the segmentation model; all parameters in **S** and the weights in the batch normalization layers are updated, but the statistics in the batch normalization layers (means and variances) are frozen. After about 50 epochs (around 2h15m on a GPU with 8GB VRAM) you should get a model with the accuracy reported above. To evaluate the classifier:
```
python3 evaluate_classifier.py --ckpt classification_model.pth --truncation 1 --num_class 2 --backbone_name resnet34 --roi_batch_size 128 --device cuda --s_features 512
```
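
On the batch-normalization regime described above, a minimal sketch of the general technique (an illustration, not the training script's exact code): keep BN layers in eval mode so the running statistics stay frozen, while their affine parameters remain trainable.

```python
# Sketch: train BN affine parameters but freeze running means/variances.
import torch.nn as nn

def freeze_bn_stats(model: nn.Module) -> None:
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()                       # eval mode: running stats not updated
            m.weight.requires_grad_(True)  # scale (gamma) still trained
            m.bias.requires_grad_(True)    # shift (beta) still trained

# Call after every model.train(), since train() flips BN back to training mode.
```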

## Update 29/10/20
Column 1: Input CT scan slice overlaid with the output of the segmentation model. 

Column 2: Pixel-level mask logit scores predicted by Mask R-CNN *independently of each other*, i.e. output by different RoIs and resized to fit the bounding box predictions. Note that COVID-CT-Mask-Net uses a fixed number of RoIs; only the highest-ranking RoIs are plotted here to avoid cluttering the image.

Column 3: Ground truth masks for lesions (yellow) and lungs (green, treated as background).

Column 4: True class (green) and logit scores output by COVID-CT-Mask-Net (red) using the score map's inputs. Note how the classification model learns the distribution and ranking of the regional predictions (bounding boxes and confidence scores) to predict the global (image) class.

<p align="center">
<img src="https://github.com/AlexTS1980/COVID-CT-Mask-Net/blob/master/plots/segmentation_map_classification_score.png" width="800" height="500" align="center"/>
</p>

## Update 19-22/10/20
I added a large number of updates across all models. You can now train segmentation and classification models with three types of masks: two separate masks (GGO and C), GGO only, and merged GGO and C masks ('lesion').

I added methods in the `utils` script to compute the accuracy (mean average precision) of the Mask R-CNN segmentation models. They are based on matterport's package, but implemented purely in PyTorch, with no requirement for RLE or pycocotools. A new evaluation script, `evaluation_mean_ap`, which applies these methods over a range of Intersection over Union (IoU) thresholds, has been added too.
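
As a rough sketch of what such a pure-PyTorch computation looks like (assumed function names, not the actual `utils` API; greedy matching of score-sorted predictions, non-interpolated AP):

```python
# Sketch of mask AP at one IoU threshold, pure PyTorch (illustrative only).
import torch

def mask_iou(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """IoU of two boolean masks of the same shape."""
    inter = (a & b).sum().float()
    union = (a | b).sum().float()
    return inter / union.clamp(min=1)

def average_precision(pred_masks, scores, gt_masks, iou_th=0.5):
    # Assumes at least one ground-truth mask and one prediction.
    order = torch.argsort(scores, descending=True)   # match best-scoring first
    matched = [False] * len(gt_masks)
    tp = torch.zeros(len(pred_masks))
    for rank, i in enumerate(order.tolist()):
        ious = torch.stack([mask_iou(pred_masks[i], g) for g in gt_masks])
        best = int(ious.argmax())
        if ious[best] >= iou_th and not matched[best]:
            matched[best] = True                      # each GT matched at most once
            tp[rank] = 1.0
    precision = torch.cumsum(tp, 0) / torch.arange(1, len(pred_masks) + 1)
    return (precision * tp).sum() / max(len(gt_masks), 1)  # mean over recall points

# mAP@[0.5:0.95:0.05] is then the mean of AP over the ten thresholds
# [0.5 + 0.05 * k for k in range(10)].
```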

**COVID-CT-Mask-Net (merged masks)**: COVID-19 sensitivity: 92.68%, overall accuracy: 96.33%

|   | Control | CP | COVID-19 |
|:-:|:-:|:-:|:-:|
| **Control** | 9236 | 188 | 26 |
| **CP** | 116 | 7150 | 129 |
| **COVID-19** | 20 | 298 | 4028 |

**COVID-CT-Mask-Net (GGO + C masks)**: COVID-19 sensitivity: 93.88%, overall accuracy: 95.64%

|   | Control | CP | COVID-19 |
|:-:|:-:|:-:|:-:|
| **Control** | 9158 | 278 | 14 |
| **CP** | 204 | 7030 | 161 |
| **COVID-19** | 15 | 251 | 4080 |

All segmentation scripts, as well as the segmentation dataset interface, accept the `mask_type` argument: one of `both` (GGO + C), `ggo` (GGO only), and `merge` (merged GGO and C masks). The effect on the size of the model is marginal.

## 1. Segmentation Model
<p align="center">
<img src="https://github.com/AlexTS1980/COVID-CT-Mask-Net/blob/master/plots/maskrcnncovidsegment.png" width="800" height="250" align="center"/>
</p>
To train and test the model you need torchvision 0.3.0+.

The segmentation model predicts masks of Ground Glass Opacity (GGO) and Consolidation (C) in CT scans. We trained it on the CNCB CT images with masks (http://ncov-ai.big.ac.cn/download, Experiment data files): 500 slices for training and 150 for testing, all taken from COVID-positive patients, although some slices have no lesions. Use the splits in `train_split_segmentation.txt` and `test_split_segmentation.txt` to copy the training data into `covid_data/train/imgs` and `covid_data/train/masks` and the test data into `covid_data/test/imgs` and `covid_data/test/masks`.

Download the pretrained weights into the `pretrained_models/` directory.

To run inference with the segmentation model:
```
python3.5 inference_segmentation.py --ckpt pretrained_models/segmentation_model_both_classes.pth --test_data_dir covid_data/test --test_imgs_dir imgs --masks_type both
```
This should output predictions like these:
<p align="center">
<img src="https://github.com/AlexTS1980/COVID-CT-Mask-Net/blob/master/plots/128_92_with_mask.png" width="600" height="200" align="center"/>
<img src="https://github.com/AlexTS1980/COVID-CT-Mask-Net/blob/master/plots/133_48_with_mask.png" width="600" height="200" align="center"/>
</p>

For an explanation of the plots, see the paper. To compute the average precision on the data, you also need the mask for each image. For example, for merged masks:
```
python3.5 evaluation_mean_ap.py --ckpt pretrained_weights/segmentation_model_merged_masks.pth --mask_type merge --test_data_dir covid_data/test --test_imgs_dir imgs --gt_dir masks
```
To train the segmentation model, you also need images with masks. The dataset interface `dataset_segmentation.py` converts the masks into binary masks with either two positive classes (GGO and C) or one (GGO only, or merged GGO + C). It also extracts the labels and bounding boxes that Mask R-CNN requires, as sketched below.
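
A rough sketch of that conversion for one class (illustrative only; the actual interface also handles the different `mask_type` options):

```python
# Illustrative sketch (not dataset_segmentation.py itself): binary mask and
# bounding box for one positive class in a label-map mask.
import numpy as np

def class_target(mask: np.ndarray, class_id: int):
    binary = (mask == class_id).astype(np.uint8)
    ys, xs = np.nonzero(binary)
    if xs.size == 0:
        return None                                 # slice has no such lesion
    box = [xs.min(), ys.min(), xs.max(), ys.max()]  # [x1, y1, x2, y2]
    return binary, box
```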
To train from scratch on the merged masks, run:
```
python3.5 train_segmentation.py --device cuda --num_epochs 100 --use_pretrained_model False --use_pretrained_backbone True --save_every 10 --mask_type merge
```
To get the reported results, and for the COVID-CT-Mask-Net classifier, we trained the model for 100 epochs (about 4.5 hours on a GPU with 8GB VRAM).

## 2. COVID-CT-Mask-Net (Classification Model)

### 2.1 Full model (ResNet50+FPN backbone)
**The model**
<p align="center">
<img src="https://github.com/AlexTS1980/COVID-CT-Mask-Net/blob/master/plots/covid-ct-mask-net.png" width="900" height="200" align="center"/>
</p>

**Classification module *S***
<p align="center">
<img src="https://github.com/AlexTS1980/COVID-CT-Mask-Net/blob/master/plots/s_module_final.png" width="700" height="400" align="center"/>
</p>

I reimplemented torchvision's detection library (https://github.com/pytorch/vision/tree/master/torchvision/models/detection) in `/models/mask_net/`, adding the classification module **s2_new** (**S** in the paper) and other changes that convert Mask R-CNN into a classification model.

First, download and unpack the CNCB dataset (http://ncov-ai.big.ac.cn/download), a total of over 100K CT scans. The COVIDx-CT split we used is here: https://github.com/haydengunraj/COVIDNet-CT/blob/master/docs/dataset.md. To extract the COVID-19, pneumonia, and normal scans, follow the instructions at that link. You don't need any image preprocessing as in the COVIDNet-CT model. We used the full validation and test splits and a small share of the training data; our sample is in `train_split_classification.txt`. To follow the convention used in the other two datasets, we set Class 0: Control, Class 1: Normal Pneumonia, Class 2: COVID. The dataset interface `datasets/dataset_classification.py` extracts the labels from the file names, which must follow the convention `[Class]_[PatientID]_[ScanNum]_[SliceNum].png`. To train the classifier, copy the images following this convention into a separate directory, e.g. `train_small`.
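
Label extraction from that naming convention can be as simple as the following sketch (illustrative, not the interface's exact code; the file name is hypothetical):

```python
# Sketch: read the class id from [Class]_[PatientID]_[ScanNum]_[SliceNum].png
import os

def label_from_filename(path: str) -> int:
    return int(os.path.basename(path).split("_")[0])  # leading field = class id

# Hypothetical file name: class 2 = COVID
assert label_from_filename("train_small/2_451_3037_0088.png") == 2
```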

### 2.2 Lightweight Models (Truncated ResNet18/34+FPN backbone)

I implemented two backbones, ResNet18 and ResNet34, both with a single FPN module, and two truncations: the last block or the last two blocks.

Backbone model:
<p align="center">
<img src="https://github.com/AlexTS1980/COVID-CT-Mask-Net/blob/master/plots/resnet18.png" width="800" height="250" align="center"/>
</p>

Here's the full size comparison:

| Model | Total #parameters | #Trainable parameters |
|:-:|:-:|:-:|
| **Lightweight model, 5 blocks, ResNet34+FPN** | 24.86M | 0.6M |
| **Lightweight model, 4 blocks, ResNet34+FPN** | 11.74M | 0.6M |
| **Lightweight model, 3 blocks, ResNet34+FPN** | 4.92M | 0.6M |
| **Lightweight model, 5 blocks, ResNet18+FPN** | 14.75M | 0.6M |
| **Lightweight model, 4 blocks, ResNet18+FPN** | 6.35M | 0.6M |
| **Lightweight model, 3 blocks, ResNet18+FPN** | 4.25M | 0.6M |
| **Full model, 5 blocks, ResNet50+FPN (4 layers)** | 34.14M | 2.36M |

## 3. Models' hyperparameters

There are two groups of hyperparameters: training hyperparameters (learning rate, weight regularization, optimizer, etc.) and Mask R-CNN hyperparameters (non-max suppression threshold, RPN and RoI batch size, RPN output size, RoI score threshold, etc.). The values in the training scripts are the ones we used to obtain the models and results in the paper. For the segmentation model you can use any values you want, but for COVID-CT-Mask-Net the RoI score threshold (`box_score_thresh`) must be negative (e.g. `-0.01`): otherwise not all box predictions (`box_detections_per_img`) will be accepted, the classification module **S** will not get a batch of the right size, and you will get a tensor mismatch error.
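
To make the constraint concrete, here is a sketch using torchvision's stock constructor (our re-implementation lives in `/models/mask_net/`, but these two keyword arguments exist in torchvision's Mask R-CNN as well):

```python
# Sketch: a negative score threshold keeps every RoI, so module S always
# receives a fixed-size batch of box predictions.
import torchvision

model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    pretrained=False,
    box_score_thresh=-0.01,      # accept all boxes, even with score 0
    box_detections_per_img=128,  # the fixed batch size module S expects
)
# With a positive threshold, low-scoring boxes get filtered out, fewer than
# 128 predictions reach S, and its first layer raises a size-mismatch error.
```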

[Update 22/10/20] Also, our re-implementation of torchvision's Mask R-CNN has a hack that maintains the same batch size regardless of the pre-set `box_score_thresh`.

For any questions, contact Alex Ter-Sarkisov: alex.ter-sarkisov@city.ac.uk