# Fine-Tuning LLMs for Medical Entity Extraction
> #### _Archit | Fall '23 | Duke AIPI 591 (Independent Study in GenAI) Research Project_


## Project Overview ⭐

The pharmaceutical industry relies heavily on the accurate processing of adverse event reports, a task traditionally done manually and prone to inefficiencies. Automating `Medical Entity Extraction`, particularly of `drug names` and `side effects`, is essential for enhancing patient safety and compliance. Current systems, while adept at identifying general information such as patient name, age, address, and other demographics, struggle with specialized medical terminology, creating a pressing need for more advanced solutions.

The goal of this project was to explore and implement methods for `fine-tuning` Large Language Models `(LLMs)` such as `Llama2` and `StableLM` for the specific task of extracting medical entities from adverse event reports (currently emails only). By fine-tuning these models on a synthetic dataset derived from drug information on Drugs.com, the aim was to surpass traditional entity recognition methods in accuracy and efficiency. This approach streamlines the entity extraction process and enhances the reliability and timeliness of data on drug adverse events, offering potential improvements in medical data analysis practices.


## Data Sources 💾

This project's dataset was built by extracting detailed information about the top 50 most popular drugs from [Drugs.com](https://www.drugs.com), a comprehensive and authoritative online resource for medication information. Drugs.com provides a wide range of data on pharmaceuticals, including drug descriptions, dosages, indications, and primary side effects. This rich source of information was instrumental in developing a robust and accurate synthetic dataset for the project.

The drugs selected for this study include a wide range of medications known for their prevalence in the market and significance in treatment regimens. These drugs span various therapeutic categories and include:

* `Psychiatric` drugs like Abilify.
* `Immunomodulators` such as Infliximab, Rituximab, Etanercept, Humira, and Enbrel.
* `Gastrointestinal` medications like Nexium, Prevacid, Prilosec, and Protonix.
* `Cholesterol`-lowering agents including Crestor, Lipitor, Zocor, and Vytorin.
* `Diabetes` medications such as Victoza, Byetta, Januvia, and Onglyza.
* `Respiratory` treatments like Advair, Symbicort, Spiriva, and Singulair.
* `Erectile dysfunction` drugs including Cialis, Viagra, Levitra, and Staxyn.
* `Other Medications` like AndroGel, Prezista, Doxycycline, Cymbalta, Neupogen, Epogen, Aranesp, Neulasta, Lunesta, Ambien, Provigil, Nuvigil, Metoprolol, Lisinopril, Amlodipine, Atorvastatin, Zoloft, Lexapro, Prozac, Celexa, and Atripla.

Each of these drugs was carefully chosen to provide a comprehensive view of the different types of medical entities that the LLMs would need to identify and extract from adverse event reports.


## Data Processing 📝

### **Scraping drug information from Drugs.com**

For this project, crucial drug information was scraped from Drugs.com. Each drug's dedicated webpage provides detailed information that varies in structure, making the scraping process complex.

To handle this complexity effectively, a Python script utilizing the [BeautifulSoup](https://www.crummy.com/software/BeautifulSoup/bs4/doc/) library was employed. This script parsed the HTML content of each webpage, targeting specific sections relevant to our study, such as drug uses and side effects. For text conversion, the [html2text](https://pypi.org/project/html2text/) package was used, allowing the extraction of clean and readable text data from the HTML content.
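
As a rough illustration of this step, the sketch below fetches one drug page and converts a section of it to plain text. The URL pattern and the `side-effects` section id are assumptions made for illustration; the actual selectors live in `scripts/scrape_drugs_data.py`.

```
# Illustrative scraping sketch -- URL pattern and section id are assumed,
# not necessarily what scripts/scrape_drugs_data.py uses.
import requests
import html2text
from bs4 import BeautifulSoup

def scrape_drug_page(drug_name: str) -> str:
    url = f"https://www.drugs.com/{drug_name.lower()}.html"  # assumed URL pattern
    html = requests.get(url, timeout=30).text
    soup = BeautifulSoup(html, "html.parser")

    # Hypothetical: target the side-effects section by its anchor id,
    # falling back to the whole page if it is not found.
    section = soup.find("div", id="side-effects") or soup

    converter = html2text.HTML2Text()
    converter.ignore_links = True    # keep the extracted text clean and readable
    converter.ignore_images = True
    return converter.handle(str(section))
```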

The Python script to scrape the text can be found in the `scripts` folder and can be run as follows:

**1. Create a new conda environment and activate it:**
```
conda create --name llms python=3.10.13
conda activate llms
```
**2. Install python package requirements:**
```
pip install -r requirements.txt
```
**3. Run the web scraping script:**
```
python scripts/scrape_drugs_data.py
```

### **Synthetic Dataset Generation for Fine-Tuning**

Using the scraped drug information, synthetic `Adverse Event Reports (emails)` were generated. These emails simulate real-world data while ensuring that no real patient data or personally identifiable information was used. The generation process was carried out using prompts designed to guide `ChatGPT` in creating realistic and relevant data scenarios for training purposes. The prompt template used can be found in the `data` folder and is as follows:
```
Act as an expert Analyst with 20+ years of experience in Pharma and Healthcare industry. You have to generate Adverse Event Reports in JSON format just like the following example:

{
    "input": "Nicole Moore
            moore123nicole@hotmail.com
            32 McMurray Court, Columbia, SC 41250
            1840105113, United States 
            
            Relationship to XYZ Pharma Inc.: Patient or Caregiver
            Reason for contacting: Adverse Event
            
            Message: Yes, I have been taking Metroprolol for two years now and with no problem. I recently had my prescription refilled with the same Metoprolol and I'm having a hard time sleeping at night along with running nose. Did you possibly change something with the pill...possibly different fillers? The pharmacist at CVS didn't have any information for me. Thank you, Nicole Moore", 
    "output": { 
                "drug_name":"Metroprolol", 
                "adverse_events": ["hard time sleeping at night", "running nose"]
            }
}

Now create Adverse Event Reports in a similar way for the Drug - ''' [DRUG NAME] '''

You have more information about the drug's use and its side effects below - ''' [DRUG SIDE EFFECTS] '''
```

The synthetic training dataset was generated with ground-truth "input" and "output" pairs, preparing them for use in fine-tuning the language models. This included labeling the relevant entities, specifically `drug_name` and `adverse_events`.

The following is an example of the generated data:
```
{
    "input": "Natalie Cooper,\nncooper@example.com\n6789 Birch Street, Denver, CO 80203,\n102-555-6543, United States\n\nRelationship to XYZ Pharma Inc.: Patient\nReason for contacting: Adverse Event\n\nMessage: Hi, after starting Abilify for bipolar disorder, I've noticed that I am experiencing nausea and vomiting. Are these typical reactions? Best, Natalie Cooper",

    "output": "{\"drug_name\": \"Abilify\", \"adverse_events\": [\"nausea\", \"vomiting\"]}"
}
```

The Python script to generate synthetic data can be found in the `scripts` folder. Assuming you are in the same conda environment as the previous step, it can be run as follows:

**1. Create an OpenAI API key and save it in a `.env` file:**
```
Rename the env.example file to .env and add your OpenAI API key to the file
```
**2. Run the data generation script to prepare the dataset using OpenAI's Chat Completions API (a sketch of the underlying call is shown below):**
```
python scripts/data-prepare.py
```
**3. Run the data aggregation script to prepare train and test splits:**
```
python scripts/combine-data.py
```

These scripts generate a supervised dataset of `input`/`output` pairs, where the input is the adverse event email and the output is the extracted entities. The generated data is stored in the `entity-extraction-train-data.json` and `entity-extraction-test-data.json` files in the `data/entity-extraction` folder. We have 700 training samples and ~70 test samples.
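
For context, the generation step boils down to filling the prompt template and calling the Chat Completions API. The sketch below shows that idea using the OpenAI Python client (v1+ interface); the model name and other details are assumptions, not the exact code of `data-prepare.py`.

```
# Conceptual sketch of the generation call -- model name and details are assumed.
from dotenv import load_dotenv
from openai import OpenAI

load_dotenv()        # reads OPENAI_API_KEY from the .env file
client = OpenAI()

def generate_reports(drug_name: str, side_effects_text: str) -> str:
    template = open("data/prompt-template.txt").read()
    prompt = (template.replace("[DRUG NAME]", drug_name)
                      .replace("[DRUG SIDE EFFECTS]", side_effects_text))
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",   # assumed model choice
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
    )
    return response.choices[0].message.content   # JSON-formatted report(s)
```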


## Fine-tuning Large Language Models (LLMs) for Medical Entity Extraction 🧠

In this project, two Large Language Models (LLMs), `Llama2` and `StableLM`, were fine-tuned using `Parameter Efficient Fine-Tuning (PEFT)` techniques, specifically the `Adapter V2` and `LoRA (Low-Rank Adaptation)` methods. PEFT techniques allow large models to be adapted without retraining all of their parameters, making the fine-tuning process more efficient and resource-friendly. This approach is particularly valuable for tasks that require domain-specific adaptations without losing the broad contextual knowledge the models already possess.

The fine-tuning assesses and compares the effectiveness of the two techniques in enhancing the models' performance on medical entity extraction. Both approaches aim to balance efficiency and precision, ensuring that the models can accurately identify and extract relevant medical information from complex textual data.

### **Downloading Pre-trained LLMs**
Use the following steps to download the pre-trained LLMs from HuggingFace and convert them to LIT-GPT checkpoints. The checkpoints are stored in the `checkpoints` folder.
**1. Download the pre-trained LLMs from HuggingFace:**
```
python scripts/download.py --repo_id stabilityai/stablelm-base-alpha-3b
```
```
python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf --access_token your_hf_token
```
**2. Convert the HuggingFace checkpoint to a LIT-GPT checkpoint:**
```
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b
```
```
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
```

### **Approach 1: Fine-tuning using LoRA Parameter Efficient Fine-Tuning (PEFT)**

LoRA focuses on updating the weight matrices of the pre-trained model through low-rank matrix decomposition. By altering only a small subset of the model's weights, LoRA achieves fine-tuning with minimal updates, maintaining the model's overall structure and pre-trained knowledge while adapting it to specific tasks.
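
Conceptually, LoRA replaces a frozen weight matrix `W` with `W + (alpha / r) * B @ A`, where `A` and `B` are small rank-`r` matrices and only they are trained. The following is a minimal PyTorch sketch of that idea; the actual implementation used here lives in `lit_gpt` and `finetune/lora.py`.

```
# Minimal LoRA layer sketch -- illustrative only, not lit_gpt's implementation.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)   # pre-trained weights stay frozen
        self.lora_a = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no change at start
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen path plus scaled low-rank update: W x + scaling * B (A x)
        return self.base(x) + self.scaling * (x @ self.lora_a.T @ self.lora_b.T)
```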

In this approach, only a limited set of weights is fine-tuned on the synthetic medical entity dataset we generated. The hyperparameters used are as follows:
- **Model:** Llama-2-7b or stable-lm-3b
- **Batch Size:** 16
- **Learning Rate:** 3e-4
- **Weight Decay:** 0.01
- **Epoch Size:** 700
- **Num Epochs:** 5
- **Warmup Steps:** 100

The data is first prepared by tokenizing the text data and converting it into a torch dataset (see the sketch below). The model is then fine-tuned on the data using the [Lightning](https://www.pytorchlightning.ai/) framework.
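
Conceptually, `prepare_entity_extraction_data.py` turns each JSON sample into a prompt/response token sequence with the prompt tokens masked out of the loss. Below is a simplified sketch, assuming a lit_gpt-style tokenizer that returns tensors; the real script's prompt format and field names may differ.

```
# Simplified data preparation sketch -- prompt format and masking are assumptions.
import json
import torch

IGNORE_INDEX = -100   # label value ignored by the cross-entropy loss

def prepare(json_path: str, tokenizer, out_path: str) -> None:
    samples = json.load(open(json_path))
    dataset = []
    for sample in samples:
        prompt = f"### Input:\n{sample['input']}\n\n### Output:\n"
        encoded_prompt = tokenizer.encode(prompt)                 # tensor of token ids
        encoded_full = tokenizer.encode(prompt + sample["output"])
        labels = encoded_full.clone()
        labels[: len(encoded_prompt)] = IGNORE_INDEX              # don't train on the prompt
        dataset.append({"input_ids": encoded_full, "labels": labels})
    torch.save(dataset, out_path)   # e.g. data/entity-extraction/train.pt
```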

The model is fine-tuned on 1 GPU (48GB) for 5 epochs. The data preparation and fine-tuning scripts can be found in the `scripts` and `finetune` folders respectively. Assuming you are in the same conda environment as the previous step, the scripts can be run as follows:

**1. Prepare data for fine-tuning (Stable-LM):**
```
python scripts/prepare_entity_extraction_data.py --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b
```
**2. Run the fine-tuning script (Stable-LM):**
```
python finetune/lora.py --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b --out_dir out/lora/Stable-LM/entity_extraction
```
**3. Prepare data for fine-tuning (Llama-2):**
```
python scripts/prepare_entity_extraction_data.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-hf
```
**4. Run the fine-tuning script (Llama-2):**
```
python finetune/lora.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-hf --out_dir out/lora/Llama-2/entity_extraction
```

### **Approach 2: Fine-tuning using Adapter Parameter Efficient Fine-Tuning (PEFT)**

The Adapter-V2 technique involves inserting small, trainable layers (adapters) into the model's architecture. These adapters learn task-specific features while the majority of the model's original parameters remain frozen. This approach enables efficient fine-tuning, as only a small fraction of the model's parameters are updated, reducing computational requirements and preserving the pre-trained knowledge.
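
The core idea is a small bottleneck module with a residual connection, inserted into the transformer blocks. A conceptual PyTorch sketch follows; lit_gpt's Adapter-V2 additionally learns biases and scaling factors, so treat this as illustration only.

```
# Conceptual bottleneck adapter -- illustrative, not lit_gpt's exact Adapter-V2.
import torch
import torch.nn as nn

class Adapter(nn.Module):
    def __init__(self, hidden_size: int, bottleneck: int = 64):
        super().__init__()
        self.down = nn.Linear(hidden_size, bottleneck)   # project down
        self.up = nn.Linear(bottleneck, hidden_size)     # project back up
        nn.init.zeros_(self.up.weight)                   # zero init: identity at start
        nn.init.zeros_(self.up.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection preserves the frozen model's behavior at initialization.
        return x + self.up(torch.relu(self.down(x)))
```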

In this approach, only a limited set of weights is fine-tuned on the synthetic medical entity dataset we generated. The hyperparameters used are as follows:
- **Model:** Llama-2-7b or stable-lm-3b
- **Batch Size:** 8
- **Learning Rate:** 3e-3
- **Weight Decay:** 0.02
- **Epoch Size:** 700
- **Num Epochs:** 5
- **Warmup Steps:** 2 * (epoch_size // micro_batch_size) // devices // gradient_accumulation_iters

The data is first prepared by tokenizing the text data and converting it into a torch dataset. The model is then fine-tuned on the data using the [Lightning](https://www.pytorchlightning.ai/) framework.

The model is fine-tuned on 1 GPU (24GB) for 5 epochs. The data preparation and fine-tuning scripts can be found in the `scripts` and `finetune` folders respectively. Assuming you are in the same conda environment as the previous step, the scripts can be run as follows:

**1. Prepare data for fine-tuning (Stable-LM):**
```
python scripts/prepare_entity_extraction_data.py --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b
```
**2. Run the fine-tuning script (Stable-LM):**
```
python finetune/adapter_v2.py --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b --out_dir out/adapter/Stable-LM/entity_extraction
```
**3. Prepare data for fine-tuning (Llama-2):**
```
python scripts/prepare_entity_extraction_data.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-hf
```
**4. Run the fine-tuning script (Llama-2):**
```
python finetune/adapter_v2.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-hf --out_dir out/adapter/Llama-2/entity_extraction
```


## Model Inference 🧪

Once all the models are fine-tuned, the next step is to generate predictions on the test dataset. The predictions of the fine-tuned models can be generated using the following steps:

**1. Generate predictions for models fine-tuned using Adapter PEFT:**
```
python generate/inference_adapter.py --model-type "stablelm" --input-file "..data/entity_extraction/entity-extraction-test-data.json"
```
```
python generate/inference_adapter.py --model-type "llama2" --input-file "..data/entity_extraction/entity-extraction-test-data.json"
```

**2. Generate predictions for models fine-tuned using LoRA PEFT:**
```
python generate/inference_lora.py --model-type "stablelm" --input-file "..data/entity_extraction/entity-extraction-test-data.json"
```
```
python generate/inference_lora.py --model-type "llama2" --input-file "..data/entity_extraction/entity-extraction-test-data.json"
```
   

## Performance Evaluation and Metrics 📊

The effectiveness of the fine-tuned models was evaluated based on their precision and recall in identifying medical entities. These metrics provided insight into the models' accuracy and reliability, both relative to each other across training techniques and relative to their base versions.

The test dataset, which was kept aside during the data generation process, was used to evaluate the performance of the fine-tuned models. It contains 70 samples, and the models were evaluated on precision and recall for the `drug_name` and `adverse_events` entities.

The evaluation script can be found in the `scripts` folder and can be run as follows:
```
python scripts/evaluate.py
```
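
For reference, entity-level precision and recall can be computed with simple set operations per report, as in the sketch below; the exact matching rules in `scripts/evaluate.py` (e.g. case-folding or partial matches) may differ.

```
# Set-based precision/recall sketch -- evaluate.py's matching rules may differ.
def precision_recall(predicted: list[str], actual: list[str]) -> tuple[float, float]:
    pred = {p.strip().lower() for p in predicted}
    gold = {a.strip().lower() for a in actual}
    true_positives = len(pred & gold)
    precision = true_positives / len(pred) if pred else 0.0
    recall = true_positives / len(gold) if gold else 0.0
    return precision, recall

# Example: precision_recall(["nausea", "vomitting"], ["nausea", "vomiting"])
# -> (0.5, 0.5), since only "nausea" matches exactly.
```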


Based on the evaluation, the following table shows the performance of the different models:

| Model Type | Training Technique | Precision | Recall |
| --- | :---: | :---: | :---: |
| Llama-2-7b-hf | Base Model | 0.00 | 0.00 |
| Llama-2-7b-hf | PEFT (LoRA) | 0.87 | 0.85 |
| **Llama-2-7b-hf** | **PEFT (Adapter)** | **0.88** | **0.89** |
| stablelm-base-alpha-3b | Base Model | 0.00 | 0.00 |
| stablelm-base-alpha-3b | PEFT (LoRA) | 0.81 | 0.82 |
| stablelm-base-alpha-3b | PEFT (Adapter) | 0.85 | 0.83 |

The base models were not able to properly identify any of the entities in the test dataset. Sometimes, with few-shot prompts, the base models could identify the entities, but their results were not structured properly or parsable. The fine-tuned models, on the other hand, identified the entities very well, with similar performance across both PEFT techniques. The 7-billion-parameter `Llama-2` model fine-tuned with `PEFT (Adapter)` performed slightly better than the 3-billion-parameter Stable-LM model.


## Future Work 📈

* Potential future developments include creating a user-friendly interface or tool that leverages these fine-tuned models. Such a tool would be invaluable for pharmaceutical companies and medical professionals, enabling efficient and accurate extraction of medical entities from various reports.
* We also plan to include more variety in the data, covering more drugs and more side effects.
* We can also look into pre-training the LLMs on a larger biomedical corpus and then fine-tuning them on real-world medical adverse event reports. This would help the models learn more about the different types of medical entities and improve their performance.


## Project Structure 🧬
The project structure is as follows:
```
├── data                                        <- directory for project data
    ├── entity-extraction                       <- directory for processed entity extraction data
        ├── entity-extraction-data.json         <- full prepared synthetic dataset
        ├── entity-extraction-test-data.json    <- test data for entity extraction
        ├── entity-extraction-train-data.json   <- train data for entity extraction
        ├── train.pt                            <- python pickle file for train data
        ├── test.pt                             <- python pickle file for test data
    ├── entity_extraction_reports               <- directory of generated adverse event reports
        ├── [DRUG NAME].json                    <- synthetic adverse event report for the drug
    ├── raw_drug_info                           <- directory for raw scraped drug information
        ├── [DRUG NAME].txt                     <- raw scraped drug information
    ├── predictions-llama2-adapter.json         <- predictions of the Llama-2 fine-tuned using Adapter PEFT
    ├── predictions-llama2-lora.json            <- predictions of the Llama-2 fine-tuned using LoRA PEFT
    ├── predictions-stablelm-adapter.json       <- predictions of the Stable-LM fine-tuned using Adapter PEFT
    ├── predictions-stablelm-lora.json          <- predictions of the Stable-LM fine-tuned using LoRA PEFT
    ├── prompt-template.txt                     <- prompt used to generate synthetic data
├── finetune                                    <- directory for fine-tuning scripts
    ├── adapter_v2.py                           <- script to fine-tune LLMs using Adapter PEFT
    ├── lora.py                                 <- script to fine-tune LLMs using LoRA PEFT
├── generate                                    <- directory for inference scripts
    ├── inference_adapter.py                    <- script to generate predictions using Adapter PEFT
    ├── inference_base.py                       <- script to generate predictions using base LLMs
    ├── inference_lora.py                       <- script to generate predictions using LoRA PEFT
├── lit_gpt                                     <- directory for LIT-GPT framework code
├── notebooks                                   <- directory to store any exploration notebooks used
├── performance_testing                         <- directory to store performance testing data
    ├── test_answers_analysis.xlsx              <- analysis of the answers returned by the pipeline
├── scripts                                     <- directory for pipeline scripts or utility scripts
    ├── combine_data.py                         <- script to combine the generated data for all drugs
    ├── convert_hf_checkpoint.py                <- script to convert a HuggingFace checkpoint to a LIT-GPT checkpoint
    ├── data-prepare.py                         <- script to prepare the synthetic dataset for fine-tuning
    ├── download.py                             <- script to download the pre-trained LLMs from HuggingFace
    ├── evaluate.py                             <- script to evaluate the performance of the fine-tuned models
    ├── prepare_entity_extraction_data.py       <- script to tokenize the processed data and create torch datasets
    ├── scrape_drugs_data.py                    <- script to scrape drug information from Drugs.com
├── .gitignore                                  <- git ignore file
├── LICENSE                                     <- license file
├── README.md                                   <- description of project and how to set up and run it
├── requirements.txt                            <- requirements file to document dependencies
```

## References 📚

- [Lightning-AI/lit-gpt](https://github.com/Lightning-AI/lit-gpt/tree/cf5542a166d71c0026b35428113092eb41029a8f)
- [HuggingFace Transformers](https://huggingface.co/transformers/)
- [ChatGPT API](https://platform.openai.com/docs/guides/chat)