# ML-Quiz-XRay-ReportGeneration

Finetune a Multimodal Model for X-Ray Radiology Report Generation

This repository finetunes LLaVA (Large Language and Vision Assistant) on the IU X-Ray dataset. The `XRay-ReportGeneration.pdf` defines two tasks:

- Task 1. Prompt engineering: reorganize the X-Ray report findings into predefined anatomical regions.
- Task 2. Efficient model fine-tuning: fine-tune LLaVA on the provided training set for report generation using parameter-efficient fine-tuning methods (e.g., LoRA).


## Environments and Requirements

- [Google Colab](https://colab.research.google.com/)
- A100 GPU and 100 compute units

Setup and install:

1. Mount your Google Drive account
```setup
from google.colab import drive
drive.mount('/content/drive')
```

2. Clone the quiz repo from GitHub, which includes the finetuned model checkpoints
```setup
!git clone https://github.com/Shen16/ML-Quiz-XRay-ReportGeneration.git
%cd ML-Quiz-XRay-ReportGeneration
```

3. Clone the `LLaVA` repo and install it in editable mode
```setup
!git clone https://github.com/haotian-liu/LLaVA.git
!cd LLaVA && pip install --upgrade pip && pip install -e .
```

4. Install FlashAttention and DeepSpeed
```setup
!cd LLaVA && pip install -e ".[train]"
!pip install flash-attn --no-build-isolation
!pip install deepspeed
```

5. Install other packages
```setup
!pip install -r requirements.txt
```


## Dataset

- Data was obtained from the [IU-X-ray](https://paperswithcode.com/dataset/iu-x-ray) dataset.
- Download the raw dataset from this [link](https://drive.google.com/file/d/1lBEpxrmwBkgVTZGZ92mCu0qcoq58-yMI/view?usp=sharing).

## Preprocessing

Description of the preprocessing method:
- The `prompt.ipynb` notebook was used to generate the categorized reports on the validation split for Task 1. The output is saved in the `annotation.json` file.
- The dataset was processed to create a single image-report pair per patient, using only the `0.png` image.

*Note: Multiple images per patient were not loaded due to GPU and memory limitations.*


Running the data preprocessing code:

1. Run the script to generate the reports for Task 1. Get your OpenAI API key and replace `<openAI_API_key>` with it:
```python
!python prompt.py -i 'data/annotation_quiz_all.json' -o 'data/annotation.json' -k "<openAI_API_key>"
```
*Alternatively, you can run the `prompt.ipynb` notebook in Colab.*
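
Under the hood, Task 1 sends each report to an OpenAI chat model with instructions to regroup the findings by region. The following is a minimal sketch of that idea, not the exact code in `prompt.py`; the model name and prompt wording here are assumptions:
```python
# Sketch of the Task 1 prompt call (see prompt.py for the real prompt).
from openai import OpenAI

client = OpenAI(api_key="<openAI_API_key>")

SYSTEM_PROMPT = (
    "Reorganize the X-ray report findings into the anatomical regions "
    "bone, heart, lung, mediastinal, and others. Wrap each region in "
    "<s_REGION>...</s_REGION> tags and leave a region empty if nothing applies."
)

def categorize_report(report: str) -> str:
    resp = client.chat.completions.create(
        model="gpt-4o-mini",  # assumption: prompt.py may use a different model
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": report},
        ],
        temperature=0,  # deterministic output for reproducibility
    )
    return resp.choices[0].message.content
```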

2. Generate image-report pairs for each patient and write a `dataset.json` file with the following structure for training:

```json
[
    {
        "id": "360691e3-610f-4d6d-9f71-99672524eb89",
        "image": "360691e3-610f-4d6d-9f71-99672524eb89.png",
        "conversations": [
            {
                "from": "human",
                "value": "Please describe the findings in the X-ray."
            },
            {
                "from": "gpt",
                "value": "<s_bone>Degenerative changes are present in the spine.</s_bone><s_heart>Heart size and pulmonary vascularity appear within normal limits.</s_heart><s_lung>Lungs are free of focal airspace disease. No pneumothorax or pleural effusion is seen.</s_lung><s_mediastinal></s_mediastinal><s_others>A large hiatal hernia is noted.</s_others>"
            }
        ]
    },
    {
        "id": "511d0a14-3b13-4814-8658-baf1c47b4188",
        "image": "511d0a14-3b13-4814-8658-baf1c47b4188.png",
        "conversations": [
            {
                "from": "human",
                "value": "Please describe the findings in the X-ray."
            },
            {
                "from": "gpt",
                "value": "<s_bone>Bony structures are intact.</s_bone><s_heart>Cardiac contours are within normal limits.</s_heart><s_lung>Lungs are clear.</s_lung><s_mediastinal>Mediastinal contours are within normal limits.</s_mediastinal><s_others></s_others>"
            }
        ]
    },

    ...
]
```
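
Because both the training targets and the model outputs use these region tags, a small helper to split a tagged report back into its regions is useful for inspection and evaluation. This is a sketch; the tag set is taken from the example above:
```python
import re

REGIONS = ["bone", "heart", "lung", "mediastinal", "others"]

def split_regions(tagged_report: str) -> dict:
    """Extract the text between each <s_REGION>...</s_REGION> pair."""
    sections = {}
    for region in REGIONS:
        m = re.search(rf"<s_{region}>(.*?)</s_{region}>", tagged_report, re.DOTALL)
        sections[region] = m.group(1).strip() if m else ""
    return sections

# split_regions("<s_bone>Bony structures are intact.</s_bone>...")["bone"]
# -> "Bony structures are intact."
```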

Each image is assigned a unique id. Organize the data as follows in `./ML-Quiz-XRay-ReportGeneration`:
```
├── dataset_train
│   ├── images
│   │   ├── 00a7a45c-af6e-4f5b-b305-595620ae4deb.png
│   │   └── ...
│   └── train
│       └── train_dataset.json
│
├── dataset_test
│   ├── images
│   │   ├── 00a7a45c-af6e-4f5b-b305-595620ae9ner.png
│   │   └── ...
│   └── test
│       └── test_dataset.json
│
└── dataset_val
    ├── images
    │   ├── 00a7a45c-af6e-4f5b-b305-595620ae0reb.png
    │   └── ...
    └── val
        └── val_dataset.json
```

Run this script to generate the files for each split:
```python
!python preprocess_data.py -s 'train'
!python preprocess_data.py -s 'val'
!python preprocess_data.py -s 'test'
```
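
For orientation, the core of the preprocessing looks roughly like the sketch below: copy each patient's `0.png` under a fresh UUID and emit one conversation entry per patient. The annotation keys and raw image layout here are assumptions; the authoritative logic is in `preprocess_data.py`:
```python
# Hypothetical sketch of one split's preprocessing (see preprocess_data.py).
import json
import shutil
import uuid
from pathlib import Path

def build_split(split: str, cases: list, raw_image_root: Path) -> None:
    out_root = Path(f"dataset_{split}")
    (out_root / "images").mkdir(parents=True, exist_ok=True)
    (out_root / split).mkdir(parents=True, exist_ok=True)

    entries = []
    for case in cases:  # assumption: one dict per patient with "id" and "report"
        uid = str(uuid.uuid4())
        # Only the first view (0.png) is used, per the note above.
        shutil.copy(raw_image_root / case["id"] / "0.png",
                    out_root / "images" / f"{uid}.png")
        entries.append({
            "id": uid,
            "image": f"{uid}.png",
            "conversations": [
                {"from": "human", "value": "Please describe the findings in the X-ray."},
                {"from": "gpt", "value": case["report"]},  # region-tagged report
            ],
        })

    with open(out_root / split / f"{split}_dataset.json", "w") as f:
        json.dump(entries, f, indent=4)
```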


## Training

To train the model, open and run this notebook in Google Colab:
```train
finetune.ipynb
```

1. Set an environment variable to reduce memory fragmentation issues
```
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
```

2. Run the `train_mem.py` script from the LLaVA repo to finetune the model. LoRA adapters were used for parameter-efficient finetuning, with the `liuhaotian/llava-v1.5-7b` base model trained for 1 epoch (to stay within the compute limit). The finetuned model weights are saved to the `./checkpoints/llava-v1.5-7b-task-lora` folder.
```
!deepspeed LLaVA/llava/train/train_mem.py \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --deepspeed LLaVA/scripts/zero3.json \
    --model_name_or_path liuhaotian/llava-v1.5-7b \
    --version v1 \
    --data_path ./dataset_train/train/train_dataset.json \
    --image_folder ./dataset_train/images \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir ./checkpoints/llava-v1.5-7b-task-lora \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-4 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to tensorboard
```
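
For reference, the LoRA flags above correspond roughly to the PEFT configuration sketched below. This is an illustration of what `--lora_r 128 --lora_alpha 256` means, not the config object `train_mem.py` literally builds; the dropout value and target modules are assumptions based on LLaVA's defaults:
```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=128,              # rank of the low-rank update matrices
    lora_alpha=256,     # updates are scaled by alpha / r = 2
    lora_dropout=0.05,  # assumption: LLaVA's default
    bias="none",
    task_type="CAUSAL_LM",
    # assumption: LLaVA targets all linear layers of the language model
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)
```
With r=128, each adapted weight matrix gets two small trainable factors instead of a full update, which is what keeps the run within the Colab compute budget.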


## Trained Models

You can download the trained model here:

1. Download the model checkpoint folder from the [zip file](https://drive.gogithubZipfile.git).

2. To load the model, execute the following code:
```python
# merge the LoRA weights with the full base model
!python LLaVA/scripts/merge_lora_weights.py --model-path checkpoints/llava-v1.5-7b-task-lora --model-base liuhaotian/llava-v1.5-7b --save-model-path llava-ftmodel
```
*Make sure the `llava-ftmodel` folder is in the `./ML-Quiz-XRay-ReportGeneration` directory.*
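
To load the merged model in Python rather than through the CLI, LLaVA's loader can be used along these lines (a sketch; the paths follow the commands above):
```python
# Sketch: load the merged checkpoint with LLaVA's model loader.
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path

model_path = "llava-ftmodel"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,  # None because the LoRA weights are already merged
    model_name=get_model_name_from_path(model_path),
)
```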


## Inference

1. To infer on the test cases, run the following code on a test sample image:
```python
!python -m llava.serve.cli \
  --model-path llava-ftmodel \
  --image-file "/content/drive/My Drive/ML-Quiz-XRay-ReportGeneration/dataset_test/images/cf33da4a-49f3-4dd1-8e5b-038d2637751f.png"
```
*Note: Replace the image id "cf33da4a-49f3-4dd1-8e5b-038d2637751f.png" with an existing id.*


2. Generate predictions on the test set:
```python
!python -m llava.eval.model_vqa_science \
    --model-path llava-ftmodel \
    --question-file ./dataset_test/test/test_dataset.json \
    --image-folder ./dataset_test/images \
    --answers-file ./dataset_test/answers/llava-v1ft.5-7b.jsonl \
    --single-pred-prompt \
    --temperature 0 \
    --conv-mode vicuna_v1
```
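
The answers file is JSON Lines, one prediction per line. A sketch for pairing predictions with the reference reports follows; the field names (`question_id`, `text`) match LLaVA's evaluation scripts, but verify them against your output:
```python
import json

# Load predictions: `text` holds the generated region-tagged report.
preds = {}
with open("dataset_test/answers/llava-v1ft.5-7b.jsonl") as f:
    for line in f:
        ans = json.loads(line)
        preds[ans["question_id"]] = ans["text"]

# Load references from the ground-truth conversations (the gpt turn).
with open("dataset_test/test/test_dataset.json") as f:
    refs = {ex["id"]: ex["conversations"][1]["value"] for ex in json.load(f)}

pairs = [(refs[k], preds[k]) for k in preds if k in refs]
```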

3. Generate predictions on the validation set:
```python
!python -m llava.eval.model_vqa_science \
    --model-path llava-ftmodel \
    --question-file ./dataset_val/val/val_dataset.json \
    --image-folder ./dataset_val/images \
    --answers-file ./dataset_val/answers/llava-v1ft.5-7b.jsonl \
    --single-pred-prompt \
    --temperature 0 \
    --conv-mode vicuna_v1
```
*Note: I was not able to evaluate on the validation set as I exhausted my Colab compute units (this step requires an A100 GPU).*



## Evaluation

1. Clone the GREEN repo:
```python
!git clone https://github.com/Stanford-AIMI/GREEN.git
```

2. Install the required packages, then restart the kernel:
```python
# run, then restart the kernel to use the packages
%cd GREEN
!pip install -e .
```

3. Import the `green3.py` script, which saves both the GREEN summary and the result dataframe:
```python
# import libraries
from src import green # import green.py
from src.green3 import compute # modified code (green3.py) to save both GREEN summary and result_df
```
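
Conceptually, the scoring call takes parallel lists of reference and generated reports. The sketch below is hypothetical: the actual argument names and return values are defined in `src/green3.py`, and `pairs` refers to the pairing sketch in the Inference section:
```python
refs = [r for r, _ in pairs]  # reference reports
hyps = [p for _, p in pairs]  # generated reports

# Assumption: compute mirrors GREEN's scorer, returning the mean/std GREEN
# score, per-report scores, a text summary, and a result dataframe.
mean, std, scores, summary, result_df = compute(refs, hyps)
```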

4. To compute the GREEN evaluation metrics on the test set, run the notebook:
```python
GREEN_eval.ipynb
```


## Results

Our method achieves the following performance on XRay-ReportGeneration:

| Index | GREEN Score                                                     |      Bone      |     Heart      |      Lung      |  Mediastinal  |
|-------|-----------------------------------------------------------------|:--------------:|:--------------:|:--------------:|:-------------:|
| 0     | GREEN average                                                   |  0.3253968254  |  0.8163265306  |  0.6737001944  | 0.5939625850  |
| 1     | Standard deviation                                              |  0.4635559317  |  0.3657609376  |  0.3195053647  | 0.4767040097  |
| 2     | (a) False report of a finding in the candidate                  |  0.7959183673  |  0.9081632653  |  0.5425170068  | 0.8469387755  |
| 3     | (b) Missing a finding present in the reference                  |  0.9217687075  |  0.9098639456  |  0.8095238095  | 0.7942176871  |
| 4     | (c) Misidentification of a finding's anatomic location/position |  1.0000000000  |  0.9897959184  |  0.9982993197  | 0.9829931973  |
| 5     | (d) Misassessment of the severity of a finding                  |  1.0000000000  |  0.9778911565  |  0.9931972789  | 0.9948979592  |
| 6     | (e) Mentioning a comparison that isn't in the reference         |  0.9982993197  |  0.9965986395  |  0.9880952381  | 0.9965986395  |
| 7     | (f) Omitting a comparison detailing a change from a prior study |  1.0000000000  |  1.0000000000  |  1.0000000000  | 1.0000000000  |


![GREEN Score Evaluation](https://github.com/Shen16/ML-Quiz-XRay-ReportGeneration/blob/main/GREEN_Score_Test.png?raw=true)


## Acknowledgement

We thank the contributors of public datasets.