|
a |
|
b/gpt_vs_ftT5.ipynb |
|
|
1 |
{ |
|
|
2 |
"cells": [ |
|
|
3 |
{ |
|
|
4 |
"cell_type": "code", |
|
|
5 |
"execution_count": null, |
|
|
6 |
"metadata": {}, |
|
|
7 |
"outputs": [], |
|
|
8 |
"source": [ |
|
|
9 |
"from datasets import Dataset, DatasetDict\n", |
|
|
10 |
"import torch\n", |
|
|
11 |
"from random import randrange, sample\n", |
|
|
12 |
"from transformers import DataCollatorForSeq2Seq, T5ForConditionalGeneration\n", |
|
|
13 |
"import pandas as pd\n", |
|
|
14 |
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n", |
|
|
15 |
"from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType, PeftModel, PeftConfig\n", |
|
|
16 |
"from transformers import DataCollatorForSeq2Seq\n", |
|
|
17 |
"from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n", |
|
|
18 |
"from sklearn.preprocessing import MultiLabelBinarizer\n", |
|
|
19 |
"from sklearn.metrics import classification_report, roc_auc_score, precision_recall_fscore_support\n", |
|
|
20 |
"import json\n", |
|
|
21 |
"import argparse\n", |
|
|
22 |
"import tqdm\n", |
|
|
23 |
"import numpy as np\n", |
|
|
24 |
"import random\n", |
|
|
25 |
"import os\n", |
|
|
26 |
"from skllm.config import SKLLMConfig\n", |
|
|
27 |
"from skllm import MultiLabelZeroShotGPTClassifier" |
|
|
28 |
] |
|
|
29 |
}, |
|
|
30 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Synthetic data: iteration 1 is the training split, the annotated part of iteration 2 is the test split\n",
  "DATA_DIR = './synthetic_data'\n",
  "train_data = pd.read_csv(f'{DATA_DIR}/Iteration_1.csv')\n",
  "test_data = pd.read_csv(f'{DATA_DIR}/Partial_Iteration_2_annotated.csv')"
 ]
},
|
|
40 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Rows without a gold label cannot be scored, so keep only labelled rows\n",
  "test_data = test_data.dropna(axis=0, subset=['label'])"
 ]
},
|
|
50 |
{
 "attachments": {},
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Remove legacy categories (`CAREGIVER`, `EDUCATION`) from the test labels and drop rows that are left with no label."
 ]
},
|
|
58 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Categories that are no longer part of the label set\n",
  "LEGACY_LABELS = {'CAREGIVER', 'EDUCATION'}\n",
  "\n",
  "def _clean_label_string(raw):\n",
  "    \"\"\"Split a comma-separated label string, strip each label and drop legacy or empty labels.\"\"\"\n",
  "    kept = []\n",
  "    for label in raw.split(','):\n",
  "        label = label.strip()\n",
  "        # Empty tokens (e.g. from a trailing comma) would otherwise survive as a spurious '' class\n",
  "        if label and label not in LEGACY_LABELS:\n",
  "            kept.append(label)\n",
  "    return ','.join(kept)\n",
  "\n",
  "# assign() returns a new frame instead of writing into a possibly-sliced one, so re-running is safe\n",
  "test_data = test_data.assign(label=test_data['label'].map(_clean_label_string))\n",
  "# Drop rows whose labels were all legacy (or empty)\n",
  "test_data = test_data[test_data['label'] != '']"
 ]
},
|
|
75 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Rebuild a fresh DataFrame with the 'text' / 'SDOHlabels' columns read by predict(),\n",
  "# so the filtered index of test_data is not carried into the Dataset\n",
  "test_text = test_data['text'].to_list()\n",
  "test_labels = test_data['label'].to_list()\n",
  "testdf = pd.DataFrame(data={'text': test_text, 'SDOHlabels': test_labels})\n",
  "test_dataset = Dataset.from_pandas(testdf)"
 ]
},
|
|
87 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Training split in the same 'text' / 'SDOHlabels' column layout\n",
  "train_text = train_data['text'].to_list()\n",
  "train_labels = train_data['label'].to_list()\n",
  "traindf = pd.DataFrame(data={'text': train_text, 'SDOHlabels': train_labels})\n",
  "train_dataset = Dataset.from_pandas(traindf)"
 ]
},
|
|
99 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Categories scored for both the fine-tuned T5 model and the GPT zero-shot classifier\n",
  "BROAD_LABELS = set([\n",
  "    'TRANSPORTATION', 'HOUSING', 'RELATIONSHIP',\n",
  "    'PARENT', 'EMPLOYMENT', 'SUPPORT',\n",
  "])"
 ]
},
|
|
109 |
{
 "attachments": {},
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "## Fine-tuned T5"
 ]
},
|
|
117 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Directory of the fine-tuned LoRA adapter (and its tokenizer); set T5_MODEL_PATH to override the placeholder\n",
  "model_path = os.environ.get('T5_MODEL_PATH', 'path/to/finetuned/t5')"
 ]
},
|
|
126 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Token-length caps for the T5 source (input text) and target (label string)\n",
  "MAX_S_LEN = 100\n",
  "MAX_T_LEN = 40\n",
  "# Tokenizer loaded from the fine-tuned model directory\n",
  "TOKENIZER = AutoTokenizer.from_pretrained(model_path)"
 ]
},
|
|
137 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# The adapter config records which base checkpoint the LoRA weights belong to\n",
  "config = PeftConfig.from_pretrained(model_path)\n",
  "\n",
  "# Load the base seq2seq model in 8-bit on GPU 0\n",
  "reloaded_model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map={\"\": 0})\n",
  "\n",
  "# Attach the fine-tuned LoRA weights to the base model\n",
  "reloaded_model = PeftModel.from_pretrained(reloaded_model, model_path, device_map={\"\": 0})\n",
  "# Inference only (disables dropout); the trailing ';' suppresses the very long module repr\n",
  "reloaded_model.eval();"
 ]
},
|
|
154 |
{
 "attachments": {},
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Support functions: label post-processing, T5 input preprocessing, multi-label evaluation and batched prediction."
 ]
},
|
|
162 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "def postprocess_function(preds):\n",
  "    \"\"\"\n",
  "    Repair abbreviated or truncated label names in generated predictions.\n",
  "\n",
  "    Each prediction is treated as a comma-separated label string. Every label is\n",
  "    stripped of surrounding whitespace and, when it matches a known short form,\n",
  "    replaced by the full label name. Missing predictions (None or NaN) become\n",
  "    '<NO_SDOH>'.\n",
  "\n",
  "    Args:\n",
  "        preds (list): A list of predictions.\n",
  "\n",
  "    Returns:\n",
  "        list: Processed predictions with fixed labels.\n",
  "\n",
  "    Examples:\n",
  "        >>> preds = ['REL', 'EMPLO', 'HOUS', 'UNKNOWN']\n",
  "        >>> postprocess_function(preds)\n",
  "        ['RELATIONSHIP', 'EMPLOYMENT', 'HOUSING', 'UNKNOWN']\n",
  "\n",
  "        >>> preds = ['NO_SD', np.nan, 'SUPP']\n",
  "        >>> postprocess_function(preds)\n",
  "        ['<NO_SDOH>', '<NO_SDOH>', 'SUPPORT']\n",
  "    \"\"\"\n",
  "    lab_fixed_dict = {\n",
  "        'REL': 'RELATIONSHIP',\n",
  "        'RELAT': 'RELATIONSHIP',\n",
  "        'EMP': 'EMPLOYMENT',\n",
  "        'EMPLO': 'EMPLOYMENT',\n",
  "        'SUPP': 'SUPPORT',\n",
  "        'HOUS': 'HOUSING',\n",
  "        'PAREN': 'PARENT',\n",
  "        'TRANSPORT': 'TRANSPORTATION',\n",
  "        'NO_SD': '<NO_SDOH>',\n",
  "        'NO_SDOH>': '<NO_SDOH>',\n",
  "        '<NO_SDOH': '<NO_SDOH>',\n",
  "    }\n",
  "\n",
  "    new_preds = []\n",
  "    for pred in preds:\n",
  "        # str(np.nan) is 'nan', so missing values must be detected before stringifying;\n",
  "        # an np.nan dict key could never match the stringified prediction.\n",
  "        if pred is None or (isinstance(pred, float) and np.isnan(pred)):\n",
  "            new_preds.append('<NO_SDOH>')\n",
  "            continue\n",
  "        pred_ls = []\n",
  "        for pp in str(pred).split(','):\n",
  "            pp = pp.strip()\n",
  "            pred_ls.append(lab_fixed_dict.get(pp, pp))\n",
  "        new_preds.append(','.join(pred_ls))\n",
  "    return new_preds\n",
  "\n",
  "def preprocess_function(sample, padding=\"max_length\"):\n",
  "    \"\"\"\n",
  "    Tokenize a batch of examples for T5 seq2seq training.\n",
  "\n",
  "    Args:\n",
  "        sample (dict): Batch with a \"text\" column (inputs) and a \"SDOHlabels\" column (targets).\n",
  "        padding (str): Padding strategy passed to TOKENIZER.\n",
  "\n",
  "    Returns:\n",
  "        The tokenized inputs with an extra \"labels\" entry holding the target token ids.\n",
  "    \"\"\"\n",
  "    # Every input gets the same task prefix used for T5 fine-tuning\n",
  "    prompts = [\"summarize: \" + text for text in sample[\"text\"]]\n",
  "    encoded = TOKENIZER(prompts, max_length=MAX_S_LEN, padding=padding, truncation=True)\n",
  "    targets = TOKENIZER(text_target=sample[\"SDOHlabels\"], max_length=MAX_T_LEN, padding=padding, truncation=True)\n",
  "    target_ids = targets[\"input_ids\"]\n",
  "    if padding == \"max_length\":\n",
  "        # Replace pad ids with -100 so the loss ignores padded target positions\n",
  "        pad_id = TOKENIZER.pad_token_id\n",
  "        target_ids = [[-100 if tok == pad_id else tok for tok in row] for row in target_ids]\n",
  "    encoded[\"labels\"] = target_ids\n",
  "    return encoded\n",
  "\n",
  "def normal_eval(preds, gold):\n",
  "    \"\"\"\n",
  "    Score comma-separated multi-label predictions against gold labels and print a report.\n",
  "\n",
  "    Labels are whitespace-stripped and empty tokens dropped; predicted labels outside\n",
  "    BROAD_LABELS are discarded. The MultiLabelBinarizer is fitted on the gold labels,\n",
  "    so predicted labels that never occur in gold are ignored when encoding.\n",
  "\n",
  "    Args:\n",
  "        preds (list[str]): Comma-separated predicted label strings.\n",
  "        gold (list[str]): Comma-separated gold label strings.\n",
  "\n",
  "    Returns:\n",
  "        dict: 'macro_f1', 'micro_f1', 'weighted_f1' plus per-class\n",
  "        'precision_<label>', 'recall_<label>' and 'f1_<label>'.\n",
  "    \"\"\"\n",
  "    def split_labels(label_string):\n",
  "        # 'A, B' and trailing commas must not yield ' B' or '' as labels\n",
  "        return [lab.strip() for lab in label_string.split(',') if lab.strip()]\n",
  "\n",
  "    gold_list = [split_labels(g) for g in gold]\n",
  "    pred_list = [[p for p in split_labels(pred) if p in BROAD_LABELS] for pred in preds]\n",
  "\n",
  "    mlb = MultiLabelBinarizer()\n",
  "    oh_gold = mlb.fit_transform(gold_list)\n",
  "    oh_pred = mlb.transform(pred_list)\n",
  "\n",
  "    prec, rec, f1, _ = precision_recall_fscore_support(oh_gold, oh_pred)\n",
  "    metrics_out = {\n",
  "        f'{avg}_f1': precision_recall_fscore_support(oh_gold, oh_pred, average=avg)[2]\n",
  "        for avg in ('macro', 'micro', 'weighted')\n",
  "    }\n",
  "    for i, lab in enumerate(mlb.classes_):\n",
  "        metrics_out['precision_' + str(lab)] = prec[i]\n",
  "        metrics_out['recall_' + str(lab)] = rec[i]\n",
  "        metrics_out['f1_' + str(lab)] = f1[i]\n",
  "    print(classification_report(oh_gold, oh_pred, target_names=mlb.classes_))\n",
  "    return metrics_out\n",
  "\n",
  "def predict(dataset, model, batch_size, prefix=\"summarize: \", max_new_tokens=None, num_beams=4):\n",
  "    \"\"\"\n",
  "    Generate label strings for every example in dataset and collect the gold labels.\n",
  "\n",
  "    Inputs get the same prefix and MAX_S_LEN truncation as in preprocess_function,\n",
  "    so inference sees inputs formatted like the training data.\n",
  "\n",
  "    Args:\n",
  "        dataset: Dataset with \"text\" and \"SDOHlabels\" columns.\n",
  "        model: Seq2seq model exposing generate() and .device.\n",
  "        batch_size (int): Examples per generate() call.\n",
  "        prefix (str): Task prefix prepended to every input text.\n",
  "        max_new_tokens (int, optional): Generation length cap; defaults to MAX_T_LEN,\n",
  "            the target length used for training labels. A small cap can stop\n",
  "            generation mid-label (cf. the short forms handled by postprocess_function).\n",
  "        num_beams (int): Beam width for beam search.\n",
  "\n",
  "    Returns:\n",
  "        tuple[list[str], list[str]]: Decoded predictions and the matching gold label strings.\n",
  "    \"\"\"\n",
  "    if max_new_tokens is None:\n",
  "        max_new_tokens = MAX_T_LEN\n",
  "    predictions, references = [], []\n",
  "    for start in tqdm.tqdm(range(0, len(dataset), batch_size)):\n",
  "        batch = dataset[start:start + batch_size]\n",
  "        encoded = TOKENIZER(\n",
  "            [prefix + text for text in batch[\"text\"]],\n",
  "            max_length=MAX_S_LEN,\n",
  "            truncation=True,\n",
  "            padding=True,\n",
  "            return_tensors=\"pt\",\n",
  "        ).to(model.device)\n",
  "        # do_sample=False means deterministic beam search; top_p only applies when sampling\n",
  "        outputs = model.generate(\n",
  "            input_ids=encoded[\"input_ids\"],\n",
  "            attention_mask=encoded[\"attention_mask\"],\n",
  "            do_sample=False,\n",
  "            num_beams=num_beams,\n",
  "            max_new_tokens=max_new_tokens,\n",
  "        )\n",
  "        predictions.extend(TOKENIZER.batch_decode(outputs.cpu(), skip_special_tokens=True))\n",
  "        references.extend(batch[\"SDOHlabels\"])\n",
  "    return predictions, references"
 ]
},
|
|
272 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Score the raw generations, then the same generations after repairing short label forms\n",
  "predictions, references = predict(test_dataset, reloaded_model, batch_size=4)\n",
  "metrics = normal_eval(predictions, references)\n",
  "\n",
  "banner = '=' * 30\n",
  "print(banner + 'POST PROCESSED' + banner)\n",
  "processed_predictions = postprocess_function(predictions)\n",
  "processed_metrics = normal_eval(processed_predictions, references)"
 ]
},
|
|
285 |
{ |
|
|
286 |
"attachments": {}, |
|
|
287 |
"cell_type": "markdown", |
|
|
288 |
"metadata": {}, |
|
|
289 |
"source": [ |
|
|
290 |
"## GPT" |
|
|
291 |
] |
|
|
292 |
}, |
|
|
293 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Read OpenAI credentials from the environment instead of pasting them into the notebook.\n",
  "# A missing OPENAI_API_KEY raises KeyError here rather than failing later at the API call.\n",
  "SKLLMConfig.set_openai_key(os.environ[\"OPENAI_API_KEY\"])\n",
  "if os.environ.get(\"OPENAI_ORG_ID\"):\n",
  "    SKLLMConfig.set_openai_org(os.environ[\"OPENAI_ORG_ID\"])"
 ]
},
|
|
303 |
{ |
|
|
304 |
"attachments": {}, |
|
|
305 |
"cell_type": "markdown", |
|
|
306 |
"metadata": {}, |
|
|
307 |
"source": [ |
|
|
308 |
"#### Multi-Label" |
|
|
309 |
] |
|
|
310 |
}, |
|
|
311 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Zero-shot: no training texts, only the candidate labels (max_labels caps labels returned per text).\n",
  "# Sorting gives the same label order on every run; iterating a set of strings follows\n",
  "# hash order, which changes between interpreter sessions.\n",
  "clf = MultiLabelZeroShotGPTClassifier(max_labels=4)\n",
  "clf.fit(None, [sorted(BROAD_LABELS)])"
 ]
},
|
|
321 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# GPT zero-shot predictions for the test texts, and the gold labels split into lists\n",
  "labels = clf.predict(test_data['text'])\n",
  "y = test_data['label'].str.split(',').tolist()"
 ]
},
|
|
331 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# One-hot encode gold and GPT labels over the gold label vocabulary;\n",
  "# GPT labels that never occur in gold are ignored by transform()\n",
  "mlb2 = MultiLabelBinarizer()\n",
  "mlb2.fit(y)\n",
  "y = mlb2.transform(y)\n",
  "labels = mlb2.transform(labels)"
 ]
},
|
|
342 |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Per-class precision / recall / F1 of the GPT zero-shot classifier\n",
  "print(classification_report(y_true=y, y_pred=labels, target_names=mlb2.classes_))"
 ]
}
|
|
351 |
], |
|
|
352 |
"metadata": { |
|
|
353 |
"kernelspec": { |
|
|
354 |
"display_name": "models2", |
|
|
355 |
"language": "python", |
|
|
356 |
"name": "python3" |
|
|
357 |
}, |
|
|
358 |
"language_info": { |
|
|
359 |
"codemirror_mode": { |
|
|
360 |
"name": "ipython", |
|
|
361 |
"version": 3 |
|
|
362 |
}, |
|
|
363 |
"file_extension": ".py", |
|
|
364 |
"mimetype": "text/x-python", |
|
|
365 |
"name": "python", |
|
|
366 |
"nbconvert_exporter": "python", |
|
|
367 |
"pygments_lexer": "ipython3", |
|
|
368 |
"version": "3.10.10" |
|
|
369 |
}, |
|
|
370 |
"orig_nbformat": 4 |
|
|
371 |
}, |
|
|
372 |
"nbformat": 4, |
|
|
373 |
"nbformat_minor": 2 |
|
|
374 |
} |