# Prognostic Imaging Biomarker Discovery in Survival Analysis for Idiopathic Pulmonary Fibrosis

PyTorch implementation of the MICCAI 2022 paper.

Imaging biomarkers derived from medical images play an important role in diagnosis, prognosis, and therapy response assessment. Developing prognostic imaging biomarkers which can achieve reliable survival prediction is essential for prognostication across various diseases and imaging modalities. In this work, we propose a method for discovering patch-level imaging patterns which we then use to predict mortality risk and identify prognostic biomarkers. Specifically, a contrastive learning model is first trained on patches to learn patch representations, followed by a clustering method to group similar underlying imaging patterns. The entire medical image can thus be represented by a long sequence of patch representations and their cluster assignments. Then a memory-efficient clustering Vision Transformer is proposed to aggregate all the patches to predict mortality risk of patients and identify high-risk patterns. To demonstrate the effectiveness and generalizability of our model, we test the survival prediction performance of our method on two sets of patients with idiopathic pulmonary fibrosis (IPF), a chronic, progressive, and life-threatening interstitial pneumonia of unknown etiology. Moreover, by comparing the high-risk imaging patterns extracted by our model with existing imaging patterns utilised in clinical practice, we can identify a novel biomarker that may help clinicians improve risk stratification of IPF patients.

## Requirements

* python = 3.8.10
* pytorch = 1.7.1
* torchvision = 0.8.2
* CUDA 11.2

## Setup

### Representation learning

For representation learning, the data is organized in the WebDataset format, which makes it easier to write I/O pipelines for large datasets. Within the .tar file, training samples are stored as .npy files. Each sample follows the format below; a short reading example follows the listing.

```
samples.tar
|
├── 0.npy              # Random location (x1,y1,z) within the scan
|   ├── image          # (64x64x2) Crops of the CT scan at locations (x1,y1,z-1) and (x1,y1,z+1)
|   ├── image_he       # (64x64x1) Crop of the CT scan at location (x1,y1,z)
|   ├── image_pairs    # (64x64x2) Crops of the CT scan at locations (x2,y2,z-1) and (x2,y2,z+1), overlapping with the "image" crops
|   ├── image_pairs_he # (64x64x1) Crop of the CT scan at location (x2,y2,z)
|   └── idx_overall    # (int) Used internally when developing the algorithm
|
├── 1.npy              # Another location
|   └── ...
|
└── 2.npy              # Another location
    └── ...
...
```
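
As a minimal sketch of how such an archive can be iterated, the snippet below uses plain `tarfile` and `numpy` rather than the webdataset library, and assumes each .npy entry holds a pickled dict with the fields listed above:

```python
# Minimal sketch: iterate over the samples stored in samples.tar.
# Assumes each .npy entry is a pickled dict with the fields listed above.
import io
import tarfile

import numpy as np

with tarfile.open('samples.tar') as tar:
    for member in tar.getmembers():
        if not member.name.endswith('.npy'):
            continue
        buf = io.BytesIO(tar.extractfile(member).read())
        sample = np.load(buf, allow_pickle=True).item()
        print(member.name, sample['image'].shape, sample['image_pairs'].shape)
```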

First, go into the folder /DnR and run the training for representation learning with the command:

```bash
python run_dnr.py --phase train
```

After getting the trained model, you can extract patch representations for all patches with the command:

```bash
python run_dnr.py --phase test --trained_model './trainedModels/model_3_24.pth/'
```

### Clustering

Then use SphericalKMeans from the spherecluster package to cluster all the patch embeddings, for example as in the sketch below.
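A minimal sketch of this step, assuming the embeddings are stored in a hypothetical `patch_embeddings.npy` and using 64 clusters to match `--max_num_cluster 64` below:

```python
# Minimal sketch: cluster L2-normalized patch embeddings on the unit sphere.
# File names and the cluster count (64, matching --max_num_cluster below)
# are assumptions, not fixed by this repository.
import numpy as np
from sklearn.preprocessing import normalize
from spherecluster import SphericalKMeans

embeddings = np.load('patch_embeddings.npy')  # (n x d) embeddings from DnR
embeddings = normalize(embeddings)            # spherical k-means works on unit vectors

skm = SphericalKMeans(n_clusters=64)
skm.fit(embeddings)

np.save('cluster_assignments.npy', skm.labels_)  # (n,) cluster index per patch
```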

### Mortality prediction via clustering ViT

Finally, the patch embeddings and their cluster assignments are fed into the clustering ViT to predict mortality risk. For the clustering ViT, the data follows the format below; a sketch of assembling such a file comes after the listing.
```
CTscans.npy
|
├── patientEmbedding # (n x d) Embeddings of all patches within the CT scan, generated by DnR; n is the number of patches and d is the embedding dimension
├── position         # (n x 3) Coordinates of all patches in the original CT scan
├── cluster          # (n x 1) Cluster assignments of all patches, generated by SphericalKMeans
├── Dead             # 1 means the event is observed, 0 means censored
└── FollowUpTime     # The time between the CT scan date and the date of the event or censoring
```
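
A minimal sketch of assembling one patient's file in this format; the input file names and the example event values are assumptions:

```python
# Minimal sketch: pack one patient's data into CTscans.npy as a pickled dict.
# Input file names and the example Dead/FollowUpTime values are assumptions.
import numpy as np

sample = {
    'patientEmbedding': np.load('patch_embeddings.npy'),     # (n x d) from DnR
    'position': np.load('patch_positions.npy'),              # (n x 3) patch coordinates
    'cluster': np.load('cluster_assignments.npy')[:, None],  # (n x 1) from SphericalKMeans
    'Dead': 1,              # event observed (0 would mean censored)
    'FollowUpTime': 24.5,   # e.g. months from scan date to event date
}
np.save('CTscans.npy', sample)

# Reading it back requires allow_pickle:
loaded = np.load('CTscans.npy', allow_pickle=True).item()
```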

You can go into the folder and install the library by running:

```bash
python setup.py install
```

Move the compiled .so files to models/extensions, for example as sketched below, and then train the model by running the command that follows.
|
|
|
93 |
```bash |
|
|
94 |
python main.py |
|
|
95 |
--lr_drop 100 |
|
|
96 |
--epochs 100 |
|
|
97 |
--group_Q |
|
|
98 |
--batch_size 4 |
|
|
99 |
--dropout 0.1 |
|
|
100 |
--sequence_len 15000 |
|
|
101 |
--weight_decay 0.0001 |
|
|
102 |
--seq_pool |
|
|
103 |
--dataDir /dataset |
|
|
104 |
--lr 2e-5 |
|
|
105 |
--mixUp |
|
|
106 |
--SAM |
|
|
107 |
--withEmbeddingPreNorm |
|
|
108 |
--max_num_cluster 64 |
|
|
109 |
``` |

## Acknowledgement

This project borrows heavily from DETR and from End-to-end Object Detection with Adaptive Clustering Transformer.