|
a |
|
b/run.py |
|
|
1 |
import json |
|
|
2 |
import os |
|
|
3 |
from typing import Dict, List |
|
|
4 |
from pathlib import Path |
|
|
5 |
from datetime import datetime as dt |
|
|
6 |
import logging |
|
|
7 |
|
|
|
8 |
from tenacity import ( |
|
|
9 |
retry, |
|
|
10 |
stop_after_attempt, |
|
|
11 |
wait_random_exponential, |
|
|
12 |
) |
|
|
13 |
from openai import OpenAI |
|
|
14 |
import google.generativeai as genai |
|
|
15 |
from langchain.callbacks.manager import CallbackManager |
|
|
16 |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler |
|
|
17 |
from langchain.llms import Ollama |
|
|
18 |
import pandas as pd |
|
|
19 |
|
|
|
20 |
from config.config import * |
|
|
21 |
from prompts.prompt import * |
|
|
22 |
|
|
|
23 |
# The log file lives under ``logs/``; logging's FileHandler opens the file immediately
# and fails with FileNotFoundError if the directory is missing, so create it first.
Path('logs').mkdir(parents=True, exist_ok=True)
logging.basicConfig(filename=f'logs/{dt.now().strftime("%Y%m%d")}.log', level=logging.INFO, format='%(asctime)s\n%(message)s')
|
|
24 |
|
|
|
25 |
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def query_llm(
    model: str,
    llm,
    systemPrompt: str,
    userPrompt: str,
):
    """Send one system/user prompt pair to the backend selected by ``model``.

    Args:
        model: Backend model name; decides which client call convention is used.
        llm: Client object matching ``model`` (chat-completions client, an object
            exposing ``generate_content``, or a plain callable).
        systemPrompt: System instruction text.
        userPrompt: User message text.

    Returns:
        A ``(text, prompt_tokens, completion_tokens)`` tuple. Token counts are only
        read from the response in the chat-completions branch; the other two
        branches report ``0, 0``. A model name outside the known groups falls
        through and returns ``None``.

    Raises:
        Any exception from the backend call is logged and re-raised, which lets the
        ``retry`` decorator above apply its policy.
    """
    chat_completion_models = (
        'gpt-4-1106-preview',
        'gpt-3.5-turbo',
        'gpt-3.5-turbo-16k',
        'gpt-3.5-turbo-1106',
    )

    if model in chat_completion_models:
        conversation = [
            {'role': 'system', 'content': systemPrompt},
            {'role': 'user', 'content': userPrompt},
        ]
        try:
            completion = llm.chat.completions.create(model=model, messages=conversation)
        except Exception as err:
            logging.info(f'{err}')
            raise err
        answer = completion.choices[0].message.content
        usage = completion.usage
        return answer, usage.prompt_tokens, usage.completion_tokens

    if model in ('gemini-pro',):
        # This backend takes a single text, so both prompts are concatenated.
        try:
            reply = llm.generate_content(systemPrompt + userPrompt)
        except Exception as err:
            logging.info(f'{err}')
            raise err
        return reply.text, 0, 0

    if model in ('llama2:70b',):
        # The client is invoked directly as a callable and returns the text itself.
        try:
            reply = llm(systemPrompt + userPrompt)
        except Exception as err:
            logging.info(f'{err}')
            raise err
        return reply, 0, 0
|
|
59 |
|
|
|
60 |
def format_input(
    patient: List,
    dataset: str,
    form: str,
    features: List[str],
):
    """Render one patient's visits as the text block inserted into the prompt.

    Args:
        patient: Sequence of visits; each visit is indexable by feature column.
        dataset: ``'mimic-iv'`` or ``'tjh'``. For ``'mimic-iv'`` the columns whose
            name starts with a categorical feature name are treated as one-hot
            columns and decoded to the text after ``'->'`` in the column name
            (``'unknown'`` when no column equals 1); output order is the
            categorical features followed by the numerical ones. For ``'tjh'``
            every column is stringified as-is in the given order.
        form: ``'string'`` (quoted comma-separated series per feature), ``'list'``
            (bracketed series per feature) or ``'batches'`` (one section per
            visit). Any other value yields an empty string.
        features: Column names aligned with the positions inside each visit.

    Returns:
        The formatted multi-line string.
    """
    numeric_names = ['Diastolic blood pressure', 'Fraction inspired oxygen', 'Glucose', 'Heart Rate', 'Height', 'Mean blood pressure', 'Oxygen saturation', 'Respiratory rate', 'Systolic blood pressure', 'Temperature', 'Weight', 'pH']
    categorical_names = ['Capillary refill rate', 'Glascow coma scale eye opening', 'Glascow coma scale motor response', 'Glascow coma scale total', 'Glascow coma scale verbal response']

    series_by_name = {}
    ordered_names = features

    if dataset == 'mimic-iv':
        # Numerical columns are copied through as strings, one series per feature.
        series_by_name.update(
            (name, [str(visit[col]) for visit in patient])
            for col, name in enumerate(features)
            if name in numeric_names
        )
        # Categorical features are spread over several one-hot columns; pick the
        # first column set to 1 for each visit and keep the label after '->'.
        for cat_name in categorical_names:
            one_hot_cols = [col for col, name in enumerate(features) if name.startswith(cat_name)]
            decoded = []
            for visit in patient:
                hot_col = next((col for col in one_hot_cols if visit[col] == 1), None)
                if hot_col is None:
                    decoded.append('unknown')
                else:
                    decoded.append(features[hot_col].split('->')[-1])
            series_by_name[cat_name] = decoded
        ordered_names = categorical_names + numeric_names
    elif dataset == 'tjh':
        series_by_name = {
            name: [str(visit[col]) for visit in patient]
            for col, name in enumerate(features)
        }

    if form == 'string':
        return ''.join(
            '- {}: "{}"\n'.format(name, ', '.join(series_by_name[name]))
            for name in ordered_names
        )

    if form == 'list':
        return ''.join(
            '- {}: [{}]\n'.format(name, ', '.join(series_by_name[name]))
            for name in ordered_names
        )

    if form == 'batches':
        sections = []
        for visit_idx, _ in enumerate(patient):
            sections.append(f'Visit {visit_idx + 1}:\n')
            for name in ordered_names:
                series = series_by_name[name]
                shown = series[visit_idx] if visit_idx < len(series) else 'unknown'
                sections.append(f'- {name}: {shown}\n')
            # Blank line separates consecutive visits.
            sections.append('\n')
        return ''.join(sections)

    return ''
|
|
105 |
|
|
|
106 |
def _build_sub_dst_name(config: Dict, form: str, nshot, time_des: str, impute) -> str:
    """Return the output name encoding the prompt settings of one configuration."""
    name = f'{form}_{str(nshot)}shot_{time_des}'
    if config['unit'] is True:
        name += '_unit'
    if config['reference_range'] is True:
        name += '_range'
    if config.get('prompt_engineering') is True:
        name += '_cot'
    if impute == 0:
        name += '_no_impute'
    elif impute == 1:
        name += '_impute'
    elif impute == 2:
        name += '_impute_info'
    return name


def run(
    config: Dict,
    output_logits: bool=True,
    output_prompts: bool=False,
    logits_root: str='logits',
    prompts_root: str='logs',
):
    """Run one experiment configuration over the test split.

    For every test patient a prompt is built from the config; depending on the
    flags the prompt is written to disk and/or sent to the LLM, and the parsed
    prediction is pickled next to its label.

    Args:
        config: Experiment settings. Keys read here: ``dataset``, ``task``,
            ``time``, ``unit``, ``reference_range``, ``form``, ``n_shot``,
            ``model`` and the optional ``prompt_engineering`` and ``impute``
            (defaults to 1).
        output_logits: When true, query the LLM and write per-patient pickles
            (``prompt``/``pred``/``label``) plus one aggregated pickle
            (``config``/``preds``/``labels``) under ``logits_root``.
        output_prompts: When true, write each patient's prompt to a ``.txt``
            file under ``prompts_root``.
        logits_root: Root directory for prediction outputs.
        prompts_root: Root directory for prompt dumps.

    Raises:
        AssertionError: Unknown ``dataset``, ``task`` or ``form``.
        ValueError: Unknown ``time`` or ``model``.
    """
    logging.info(f'Running config: {config}\n\n')

    prompt_tokens = 0
    completion_tokens = 0

    dataset = config['dataset']
    assert dataset in ['tjh', 'mimic-iv'], f'Unknown dataset: {dataset}'
    task = config['task']
    assert task in ['outcome', 'los', 'readmission', 'multitask'], f'Unknown task: {task}'
    time = config['time']
    if time == 0:
        time_des = 'upon-discharge'
    elif time == 1:
        time_des = '1month'
    elif time == 2:
        time_des = '6months'
    else:
        raise ValueError(f'Unknown time: {time}')

    # Optional per-feature context: unit and/or reference range, one line per feature.
    if config['unit'] is True or config['reference_range'] is True:
        unit_range = ''
        # Context managers close the JSON files instead of leaking the handles.
        with open(UNIT[dataset]) as unit_file:
            unit_values = dict(json.load(unit_file))
        with open(REFERENCE_RANGE[dataset]) as range_file:
            range_values = dict(json.load(range_file))
        for feature in unit_values:
            unit_range += f'- {feature}: '
            if config['unit'] is True:
                unit_range = unit_range + unit_values[feature] + ' '
            if config['reference_range'] is True:
                unit_range = unit_range + range_values[feature]
            unit_range += '\n'
    else:
        unit_range = ''

    form = config['form']
    assert form in ['string', 'batches', 'list'], f'Unknown form: {form}'
    nshot = config['n_shot']
    if nshot == 0:
        example = ''
    elif nshot == 1:
        example = f'Here is an example of input information:\n'
        example += 'Example #1:'
        example += EXAMPLE[dataset][task][0] + '\n'
    else:
        example = f'Here are {nshot} examples of input information:\n'
        for i in range(nshot):
            example += f'Example #{i + 1}:'
            example += EXAMPLE[dataset][task][i] + '\n'

    # With prompt engineering enabled the few-shot examples built above are replaced.
    if config.get('prompt_engineering') is True:
        example = COT[dataset]
        response_format = RESPONSE_FORMAT['cot']
    else:
        response_format = RESPONSE_FORMAT[task]

    if task == 'outcome':
        task_description = TASK_DESCRIPTION_AND_RESPONSE_FORMAT[task][time_des]
    else:
        task_description = TASK_DESCRIPTION_AND_RESPONSE_FORMAT[task]

    model = config['model']
    if model in ['gpt-4-1106-preview', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo-1106']:
        llm = OpenAI(api_key=OPENAI_API_KEY)
    elif model in ['gemini-pro']:
        genai.configure(api_key=GOOGLE_API_KEY, transport='rest')
        llm = genai.GenerativeModel(model)
    elif model in ['llama2:70b']:
        llm = Ollama(model=model)
    else:
        raise ValueError(f'Unknown model: {model}')

    dataset_path = f'datasets/{dataset}/processed/fold_llm'
    impute = config.get('impute', 1)
    if impute in [1, 2]:
        xs = pd.read_pickle(os.path.join(dataset_path, 'test_x.pkl'))
    else:
        xs = pd.read_pickle(os.path.join(dataset_path, 'test_x_no_impute.pkl'))
    ys = pd.read_pickle(os.path.join(dataset_path, 'test_y.pkl'))
    pids = pd.read_pickle(os.path.join(dataset_path, 'test_pid.pkl'))
    # The first two feature names are dropped, matching the ``visit[2:]`` slice below.
    features = pd.read_pickle(os.path.join(dataset_path, 'all_features.pkl'))[2:]
    record_times = pd.read_pickle(os.path.join(dataset_path, 'test_x_record_times.pkl'))
    labels = []
    preds = []

    if output_logits:
        logits_path = os.path.join(logits_root, dataset, task, model)
        Path(logits_path).mkdir(parents=True, exist_ok=True)
        sub_dst_name = _build_sub_dst_name(config, form, nshot, time_des, impute)
        sub_logits_path = os.path.join(logits_path, sub_dst_name)
        Path(sub_logits_path).mkdir(parents=True, exist_ok=True)
    if output_prompts:
        prompts_path = os.path.join(prompts_root, dataset, task, model)
        Path(prompts_path).mkdir(parents=True, exist_ok=True)
        sub_dst_name = _build_sub_dst_name(config, form, nshot, time_des, impute)
        sub_prompts_path = os.path.join(prompts_path, sub_dst_name)
        Path(sub_prompts_path).mkdir(parents=True, exist_ok=True)

    # The input-format text depends only on the config, so build it once.
    input_format_description = INPUT_FORMAT_DESCRIPTION[form]
    if impute == 0:
        input_format_description += MISSING_VALUE_DESCRIPTION
    elif impute == 2:
        input_format_description += INSTRUCTING_MISSING_VALUE

    for x, y, pid, record_time in zip(xs, ys, pids, record_times):
        if isinstance(pid, float):
            pid = str(round(pid))
        length = len(x)
        # The first two columns of the first visit are read as sex and age.
        sex = 'male' if x[0][0] == 1 else 'female'
        age = x[0][1]
        x = [visit[2:] for visit in x]
        detail = format_input(
            patient=x,
            dataset=dataset,
            form=form,
            features=features,
        )
        userPrompt = USERPROMPT.format(
            INPUT_FORMAT_DESCRIPTION=input_format_description,
            TASK_DESCRIPTION_AND_RESPONSE_FORMAT=task_description,
            UNIT_RANGE_CONTEXT=unit_range,
            EXAMPLE=example,
            SEX=sex,
            AGE=age,
            LENGTH=length,
            RECORD_TIME_LIST=', '.join(list(map(str, record_time))),
            DETAIL=detail,
            RESPONSE_FORMAT=response_format,
        )
        if output_prompts:
            with open(os.path.join(sub_prompts_path, f'{pid}.txt'), 'w') as f:
                f.write(userPrompt)
        if output_logits:
            try:
                result, prompt_token, completion_token = query_llm(
                    model=model,
                    llm=llm,
                    systemPrompt=SYSTEMPROMPT[dataset],
                    userPrompt=userPrompt
                )
            except Exception as e:
                # Best effort: a patient whose query ultimately fails is skipped.
                logging.info(f'Query LLM Exception: {e}')
                continue
            prompt_tokens += prompt_token
            completion_tokens += completion_token
            if task == 'outcome':
                label = y[0][0]
            elif task == 'readmission':
                label = y[0][2]
            elif task == 'los':
                label = [yi[1] for yi in y]
            elif task == 'multitask':
                label = [y[0][0], y[0][2]]
            else:
                raise ValueError(f'Unknown task: {task}')
            try:
                if config.get('prompt_engineering') is True:
                    pred = result
                elif task in ['los', 'multitask']:
                    pred = [float(p) for p in result.split(',')]
                else:
                    pred = float(result)
            except Exception:
                # Unparseable response: fall back to a placeholder prediction.
                # ``Exception`` (not a bare except) so Ctrl-C / SystemExit still propagate.
                if task == 'los':
                    pred = [0] * len(label)
                elif task == 'multitask':
                    pred = [0.501, 0.501]
                else:
                    pred = 0.501
            logging.info(f'PatientID: {pid}:\nResponse: {result}\n')
            pd.to_pickle({
                'prompt': userPrompt,
                'pred': pred,
                'label': label,
            }, os.path.join(sub_logits_path, f'{pid}.pkl'))
            labels.append(label)
            preds.append(pred)
    if output_logits:
        logging.info(f'Prompts: {prompt_tokens}, Completions: {completion_tokens}, Total: {prompt_tokens + completion_tokens}\n\n')
        pd.to_pickle({
            'config': config,
            'preds': preds,
            'labels': labels,
        }, os.path.join(logits_path, sub_dst_name + '.pkl'))
|
|
319 |
|
|
|
320 |
if __name__ == '__main__':
    # Run each experiment configuration from ``params`` in order: query the LLM and
    # store predictions, without dumping the prompts to disk.
    for config in params:
        run(config=config, output_logits=True, output_prompts=False)