# run.py

import json
import os
from typing import Dict, List
from pathlib import Path
from datetime import datetime as dt
import logging

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
from openai import OpenAI
import google.generativeai as genai
from langchain.llms import Ollama
import pandas as pd

from config.config import *
from prompts.prompt import *

# Daily log file; assumes the logs/ directory already exists.
logging.basicConfig(
    filename=f'logs/{dt.now().strftime("%Y%m%d")}.log',
    level=logging.INFO,
    format='%(asctime)s\n%(message)s',
)

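# query_llm sends a (system, user) prompt pair to the selected backend and returns
# (response_text, prompt_tokens, completion_tokens). Only the OpenAI path reports
# token usage; the Gemini and Ollama paths return (text, 0, 0). tenacity retries
# transient failures with randomized exponential backoff, up to six attempts.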
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def query_llm(
    model: str,
    llm,
    systemPrompt: str,
    userPrompt: str,
):
    if model in ['gpt-4-1106-preview', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo-1106']:
        try:
            result = llm.chat.completions.create(
                model=model,
                messages=[
                    {'role': 'system', 'content': systemPrompt},
                    {'role': 'user', 'content': userPrompt},
                ],
            )
        except Exception as e:
            logging.info(f'{e}')
            raise
        return result.choices[0].message.content, result.usage.prompt_tokens, result.usage.completion_tokens
    elif model in ['gemini-pro']:
        # No separate system role is used here; the system prompt is prepended.
        try:
            response = llm.generate_content(systemPrompt + userPrompt)
        except Exception as e:
            logging.info(f'{e}')
            raise
        return response.text, 0, 0
    elif model in ['llama2:70b']:
        try:
            response = llm(systemPrompt + userPrompt)
        except Exception as e:
            logging.info(f'{e}')
            raise
        return response, 0, 0

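# format_input renders one patient's visit matrix as prompt text. For mimic-iv, the
# categorical features arrive one-hot encoded (column names of the form
# 'Category->value') and are decoded back to a single label per visit. Note the
# dataset spells the coma-scale features 'Glascow', and the matching here must
# keep that spelling.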
def format_input(
    patient: List,
    dataset: str,
    form: str,
    features: List[str],
):
    feature_values = {}
    numerical_features = ['Diastolic blood pressure', 'Fraction inspired oxygen', 'Glucose', 'Heart Rate', 'Height', 'Mean blood pressure', 'Oxygen saturation', 'Respiratory rate', 'Systolic blood pressure', 'Temperature', 'Weight', 'pH']
    categorical_features = ['Capillary refill rate', 'Glascow coma scale eye opening', 'Glascow coma scale motor response', 'Glascow coma scale total', 'Glascow coma scale verbal response']
    if dataset == 'mimic-iv':
        for i, feature in enumerate(features):
            if feature in numerical_features:
                feature_values[feature] = [str(visit[i]) for visit in patient]
        for categorical_feature in categorical_features:
            # Columns of a one-hot group share the categorical feature's name prefix.
            indexes = [i for i, f in enumerate(features) if f.startswith(categorical_feature)]
            feature_values[categorical_feature] = []
            for visit in patient:
                values = [visit[i] for i in indexes]
                if 1 not in values:
                    feature_values[categorical_feature].append('unknown')
                else:
                    for i in indexes:
                        if visit[i] == 1:
                            feature_values[categorical_feature].append(features[i].split('->')[-1])
                            break
        features = categorical_features + numerical_features
    elif dataset == 'tjh':
        for i, feature in enumerate(features):
            feature_values[feature] = [str(visit[i]) for visit in patient]

    detail = ''
    if form == 'string':
        for feature in features:
            detail += f'- {feature}: "{", ".join(feature_values[feature])}"\n'
    elif form == 'list':
        for feature in features:
            detail += f'- {feature}: [{", ".join(feature_values[feature])}]\n'
    elif form == 'batches':
        # One block per visit instead of one line per feature.
        for i, visit in enumerate(patient):
            detail += f'Visit {i + 1}:\n'
            for feature in features:
                value = feature_values[feature][i] if i < len(feature_values[feature]) else 'unknown'
                detail += f'- {feature}: {value}\n'
            detail += '\n'
    return detail

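# run() builds one prompt per test-set patient according to `config`, queries the
# model, and collects predictions alongside ground-truth labels. Prompt templates
# (USERPROMPT, SYSTEMPROMPT, EXAMPLE, COT, RESPONSE_FORMAT, ...) come in via the
# star imports above.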
def run(
    config: Dict,
    output_logits: bool = True,
    output_prompts: bool = False,
    logits_root: str = 'logits',
    prompts_root: str = 'logs',
):
    logging.info(f'Running config: {config}\n\n')

    prompt_tokens = 0
    completion_tokens = 0

    dataset = config['dataset']
    assert dataset in ['tjh', 'mimic-iv'], f'Unknown dataset: {dataset}'
    task = config['task']
    assert task in ['outcome', 'los', 'readmission', 'multitask'], f'Unknown task: {task}'
    time = config['time']
    if time == 0:
        time_des = 'upon-discharge'
    elif time == 1:
        time_des = '1month'
    elif time == 2:
        time_des = '6months'
    else:
        raise ValueError(f'Unknown time: {time}')

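    # Optional context: per-feature measurement units and/or reference ranges, read
    # from the per-dataset JSON files pointed to by UNIT and REFERENCE_RANGE.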
    if config['unit'] is True or config['reference_range'] is True:
        unit_range = ''
        with open(UNIT[dataset]) as f:
            unit_values = json.load(f)
        with open(REFERENCE_RANGE[dataset]) as f:
            range_values = json.load(f)
        for feature in unit_values.keys():
            unit_range += f'- {feature}: '
            if config['unit'] is True:
                unit_range += unit_values[feature] + ' '
            if config['reference_range'] is True:
                unit_range += range_values[feature]
            unit_range += '\n'
    else:
        unit_range = ''

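    # Few-shot block: stitch n_shot pre-rendered examples for this dataset/task
    # into the prompt; zero-shot leaves it empty.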
    form = config['form']
    assert form in ['string', 'batches', 'list'], f'Unknown form: {form}'
    nshot = config['n_shot']
    if nshot == 0:
        example = ''
    elif nshot == 1:
        example = 'Here is an example of input information:\n'
        example += 'Example #1:'
        example += EXAMPLE[dataset][task][0] + '\n'
    else:
        example = f'Here are {nshot} examples of input information:\n'
        for i in range(nshot):
            example += f'Example #{i + 1}:'
            example += EXAMPLE[dataset][task][i] + '\n'

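    # Chain-of-thought mode replaces the few-shot examples with a per-dataset CoT
    # prompt and switches to the free-text CoT response format.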
    if config.get('prompt_engineering') is True:
        example = COT[dataset]
        response_format = RESPONSE_FORMAT['cot']
    else:
        response_format = RESPONSE_FORMAT[task]

    if task == 'outcome':
        # Outcome prediction is additionally parameterized by the prediction horizon.
        task_description = TASK_DESCRIPTION_AND_RESPONSE_FORMAT[task][time_des]
    else:
        task_description = TASK_DESCRIPTION_AND_RESPONSE_FORMAT[task]

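    # Instantiate the backend client: OpenAI chat models, Gemini, or a local
    # llama2:70b served through Ollama.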
    model = config['model']
    if model in ['gpt-4-1106-preview', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo-1106']:
        llm = OpenAI(api_key=OPENAI_API_KEY)
    elif model in ['gemini-pro']:
        genai.configure(api_key=GOOGLE_API_KEY, transport='rest')
        llm = genai.GenerativeModel(model)
    elif model in ['llama2:70b']:
        llm = Ollama(model=model)
    else:
        raise ValueError(f'Unknown model: {model}')

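    # Load the preprocessed test split. The first two feature columns (sex, age) are
    # demographics, so they are sliced off all_features here and peeled off each
    # patient's visits inside the loop below; impute=0 loads the non-imputed matrix.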
    dataset_path = f'datasets/{dataset}/processed/fold_llm'
    impute = config.get('impute', 1)
    if impute in [1, 2]:
        xs = pd.read_pickle(os.path.join(dataset_path, 'test_x.pkl'))
    else:
        xs = pd.read_pickle(os.path.join(dataset_path, 'test_x_no_impute.pkl'))
    ys = pd.read_pickle(os.path.join(dataset_path, 'test_y.pkl'))
    pids = pd.read_pickle(os.path.join(dataset_path, 'test_pid.pkl'))
    features = pd.read_pickle(os.path.join(dataset_path, 'all_features.pkl'))[2:]
    record_times = pd.read_pickle(os.path.join(dataset_path, 'test_x_record_times.pkl'))
    labels = []
    preds = []

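    # Outputs are keyed by dataset/task/model plus a sub-directory name encoding the
    # prompt configuration (form, shots, horizon, unit/range/CoT/impute flags), so
    # each ablation writes to its own folder.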
    if output_logits or output_prompts:
        sub_dst_name = f'{form}_{nshot}shot_{time_des}'
        if config['unit'] is True:
            sub_dst_name += '_unit'
        if config['reference_range'] is True:
            sub_dst_name += '_range'
        if config.get('prompt_engineering') is True:
            sub_dst_name += '_cot'
        if impute == 0:
            sub_dst_name += '_no_impute'
        elif impute == 1:
            sub_dst_name += '_impute'
        elif impute == 2:
            sub_dst_name += '_impute_info'
    if output_logits:
        logits_path = os.path.join(logits_root, dataset, task, model)
        sub_logits_path = os.path.join(logits_path, sub_dst_name)
        Path(sub_logits_path).mkdir(parents=True, exist_ok=True)
    if output_prompts:
        prompts_path = os.path.join(prompts_root, dataset, task, model)
        sub_prompts_path = os.path.join(prompts_path, sub_dst_name)
        Path(sub_prompts_path).mkdir(parents=True, exist_ok=True)

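    # Main loop: render each patient's record into a prompt, optionally dump the
    # prompt, query the model, parse its reply, and persist per-patient results.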
    for x, y, pid, record_time in zip(xs, ys, pids, record_times):
        if isinstance(pid, float):
            pid = str(round(pid))
        length = len(x)
        # The first two columns of every visit are demographics: sex (1 = male) and age.
        sex = 'male' if x[0][0] == 1 else 'female'
        age = x[0][1]
        x = [visit[2:] for visit in x]
        detail = format_input(
            patient=x,
            dataset=dataset,
            form=form,
            features=features,
        )
        input_format_description = INPUT_FORMAT_DESCRIPTION[form]
        if impute == 0:
            input_format_description += MISSING_VALUE_DESCRIPTION
        elif impute == 2:
            input_format_description += INSTRUCTING_MISSING_VALUE
        userPrompt = USERPROMPT.format(
            INPUT_FORMAT_DESCRIPTION=input_format_description,
            TASK_DESCRIPTION_AND_RESPONSE_FORMAT=task_description,
            UNIT_RANGE_CONTEXT=unit_range,
            EXAMPLE=example,
            SEX=sex,
            AGE=age,
            LENGTH=length,
            RECORD_TIME_LIST=', '.join(map(str, record_time)),
            DETAIL=detail,
            RESPONSE_FORMAT=response_format,
        )
        if output_prompts:
            with open(os.path.join(sub_prompts_path, f'{pid}.txt'), 'w') as f:
                f.write(userPrompt)
        if output_logits:
            try:
                result, prompt_token, completion_token = query_llm(
                    model=model,
                    llm=llm,
                    systemPrompt=SYSTEMPROMPT[dataset],
                    userPrompt=userPrompt,
                )
            except Exception as e:
                # Give up on this patient once query_llm's retries are exhausted.
                logging.info(f'Query LLM Exception: {e}')
                continue
            prompt_tokens += prompt_token
            completion_tokens += completion_token
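            # Label layout, inferred from the indexing below:
            # y[t] = [outcome, length-of-stay, readmission].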
            if task == 'outcome':
                label = y[0][0]
            elif task == 'readmission':
                label = y[0][2]
            elif task == 'los':
                label = [yi[1] for yi in y]
            elif task == 'multitask':
                label = [y[0][0], y[0][2]]
            else:
                raise ValueError(f'Unknown task: {task}')
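            # Parse the reply: CoT keeps the raw text; 'los'/'multitask' expect
            # comma-separated numbers; other tasks expect a single float.
            # Unparseable replies fall back to a neutral placeholder and are logged.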
            try:
                if config.get('prompt_engineering') is True:
                    pred = result
                elif task in ['los', 'multitask']:
                    pred = [float(p) for p in result.split(',')]
                else:
                    pred = float(result)
            except Exception:
                if task == 'los':
                    pred = [0] * len(label)
                elif task == 'multitask':
                    pred = [0.501, 0.501]
                else:
                    pred = 0.501
                logging.info(f'PatientID: {pid}:\nResponse: {result}\n')
            pd.to_pickle({
                'prompt': userPrompt,
                'pred': pred,
                'label': label,
            }, os.path.join(sub_logits_path, f'{pid}.pkl'))
            labels.append(label)
            preds.append(pred)
    if output_logits:
        logging.info(f'Prompts: {prompt_tokens}, Completions: {completion_tokens}, Total: {prompt_tokens + completion_tokens}\n\n')
        pd.to_pickle({
            'config': config,
            'preds': preds,
            'labels': labels,
        }, os.path.join(logits_path, sub_dst_name + '.pkl'))

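# `params` is the list of experiment configs to sweep, presumably provided by the
# config.config star import.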
if __name__ == '__main__':
    for config in params:
        run(config, output_logits=True, output_prompts=False)