In [None]:
import os
import sys

# Put the directory two levels above the working directory on the import path
# so that the project's `src` package can be imported by the cells below.
src_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
print(src_path)
sys.path.append(src_path)

In [None]:
from src.utils import create_directory, raw_data_path, processed_data_path, set_seed, remote_project_path

In [None]:
# Fix the random seed through the project helper for reproducibility.
# NOTE(review): which RNGs set_seed covers is defined in src.utils and not visible here.
set_seed(seed=42)

In [None]:
import pandas as pd

In [None]:
# "output" directory under the remote project root; the answer file written at
# the end of the notebook is placed below it.
model_path = os.path.join(remote_project_path, "output")

In [None]:
# "mimic4" directory under the processed-data root. Despite the name, this
# notebook only READS from it (cohort table, QA pairs, per-admission event files).
output_path = os.path.join(processed_data_path, "mimic4")

In [None]:
# Load the test-subset cohort table, report its dimensions and preview the first rows.
cohort_file = os.path.join(output_path, "cohort_test_subset.csv")
cohort = pd.read_csv(cohort_file)
print(cohort.shape)
cohort.head()

In [None]:
# Distinct admission ids in the cohort, collected as a set of plain Python values.
unique_hadm_ids = cohort["hadm_id"].unique()
hadm_ids = set(unique_hadm_ids.tolist())
len(hadm_ids)

In [None]:
def _label_source(event_type):
    """Return "note" when the question has no event type, otherwise "event"."""
    return "note" if pd.isna(event_type) else "event"


# Load the test-subset question/answer pairs and tag each row with its source.
qa_file = os.path.join(output_path, "qa_test_subset.csv")
qa = pd.read_csv(qa_file)
qa["source"] = qa["event_type"].apply(_label_source)
qa

In [None]:
def get_events(hadm_id):
    """Render the selected events of one admission as newline-separated text.

    Reads ``event_selected/event_{hadm_id}.csv`` below ``output_path`` and
    formats every row as ``"<timestamp with 2 decimals> hour, <event_type>,
    <event_value>"``.

    Args:
        hadm_id: Admission id used to build the event file name.

    Returns:
        A single string with one line per event row, in file order
        (empty string when the file has no rows).
    """
    event_file = os.path.join(output_path, f"event_selected/event_{hadm_id}.csv")
    events = pd.read_csv(event_file)
    lines = [
        f"{event.timestamp:.2f} hour, {event.event_type}, {event.event_value}"
        for _, event in events.iterrows()
    ]
    return "\n".join(lines)

In [None]:
# Sanity check: show the formatted event sequence for the admission of the third QA row.
print(get_events(qa.iloc[2].hadm_id))

In [None]:
# System prompt shared by every request. This is a plain (non-f) string, so the
# braces are literal text describing the event format. The 256-token figure
# matches max_response_tokens defined later in the notebook but is not derived
# from it, so the two must be kept in sync by hand.
system_content = """You are an AI assistant specialized in analyzing ICU patient data.
You are given a sequence of clinical events from an ICU patient's hospital admission.
Each event is formatted as follows: {time elapsed after admission (in hours)}, {event type}, {event value}.
Based on this sequence of events, provide a concise and accurate answer to the question below.
Keep your response within 256 tokens."""

In [None]:
# Build one example conversation from the first QA row: the shared system
# prompt plus a user turn holding the question, a blank line, and the
# admission's formatted event sequence.
example_row = qa.iloc[0]
example_user_content = f"{example_row.q}\n\n" + get_events(example_row.hadm_id)
messages = [
    {"role": "system", "content": system_content},
    {"role": "user", "content": example_user_content},
]

In [None]:
# Inspect the first entry of the current `messages` list (the system prompt).
print(messages[0]["content"])

In [None]:
# Inspect the second entry of the current `messages` list (the user turn: question followed by events).
print(messages[1]["content"])

In [None]:
# Build the chat prompt for every QA row, keyed by (source, hadm_id).
# Rows sharing the same key overwrite each other; the last one wins.
# `messages` is deliberately assigned at module level inside the loop, so after
# this cell it refers to the prompt of the last QA row.
prompts = {}
for _, data in qa.iterrows():
    user_content = f"{data.q}\n\n" + get_events(data.hadm_id)
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]
    prompts[(data.source, data.hadm_id)] = messages
len(prompts)

In [None]:
# Show the prompt stored for the first QA row's admission under the "note" source.
# Raises KeyError if no "note" question exists for that admission.
prompts[("note", qa.iloc[0].hadm_id)]

In [None]:
import tiktoken


def num_tokens_from_message(message):
    """Estimate the prompt token count of a two-turn chat message list.

    Args:
        message: Sequence whose first two items are dicts with a "content"
            string (here: the system turn followed by the user turn). Any
            further items are ignored.

    Returns:
        Number of "gpt-4" tokens in the two content strings plus a fixed
        overhead of 11.
    """
    encoding = tiktoken.encoding_for_model("gpt-4")
    system_token_ids = encoding.encode(message[0]["content"])
    user_token_ids = encoding.encode(message[1]["content"])
    # NOTE(review): 11 presumably covers the chat-format framing tokens
    # (per-message markers, roles, reply priming) for two messages - confirm.
    return len(system_token_ids) + len(user_token_ids) + 11

In [None]:
# `messages` is the module-level name last assigned by the prompt-building loop,
# so in a top-to-bottom run this counts tokens for the LAST QA row's prompt,
# not for the example conversation built earlier from the first row.
num_tokens_from_message(messages)

In [None]:
# Estimated token count for every prompt, keyed like `prompts`.
prompts_num_tokens = {
    prompt_key: num_tokens_from_message(prompt)
    for prompt_key, prompt in prompts.items()
}

In [None]:
import numpy as np


# Summary statistics of the per-prompt token counts.
token_counts = list(prompts_num_tokens.values())

for label, stat in (("mean: ", np.mean), ("std: ", np.std), ("min: ", np.min), ("max: ", np.max)):
    print(label, stat(token_counts))

for quantile in (25, 50, 75):
    print(f"{quantile}th Quantile: ", np.percentile(token_counts, quantile))

In [None]:
# Completion budget: sent as `max_tokens` with every request and reserved when trimming prompts.
max_response_tokens = 256
# Total budget (prompt + fixed overhead + completion) enforced by trim_message.
# NOTE(review): presumably the context window of the deployed model - confirm.
token_limit = 128000

In [None]:
import copy


def trim_message(message):
    """Return a copy of ``message`` whose user turn fits the token budget.

    The budget is ``token_limit`` minus the system-prompt tokens, the fixed
    overhead of 11 tokens, and ``max_response_tokens`` reserved for the
    completion. When the prompt already fits, an unchanged deep copy is
    returned; otherwise the user content is cut to its first
    ``available_tokens`` tokens and decoded back to text.

    Args:
        message: Two-item list of chat dicts; item 0 is the system turn and
            item 1 the user turn, each with a "content" string.

    Returns:
        A deep copy of ``message``; the input is never modified.
    """
    encoding = tiktoken.encoding_for_model("gpt-4")
    system_tokens = len(encoding.encode(message[0]["content"]))
    # Tokenize the (potentially very long) user turn once and reuse the ids
    # for both the length check and the cut.
    user_token_ids = encoding.encode(message[1]["content"])

    trimmed_message = copy.deepcopy(message)

    # If the total tokens are within the limit, no trimming is needed
    if system_tokens + len(user_token_ids) + 11 + max_response_tokens <= token_limit:
        return trimmed_message

    # Clamp at zero: with a negative budget the slice below would drop tokens
    # from the END of the content instead of emptying it, leaving the prompt
    # over the limit.
    available_tokens = max(token_limit - system_tokens - 11 - max_response_tokens, 0)

    # Keep only the leading tokens of the user turn.
    trimmed_message[1]["content"] = encoding.decode(user_token_ids[:available_tokens])
    return trimmed_message

In [None]:
# Trim every prompt to the token budget and report which ones were shortened.
trimmed_prompts = {}
for k, v in prompts.items():
    trimmed_v = trim_message(v)
    if trimmed_v != v:
        print(f"{k} is trimmed")
    # Reuse the result computed above instead of tokenizing the prompt a second time.
    trimmed_prompts[k] = trimmed_v
len(trimmed_prompts)

In [None]:
import asyncio
from openai import AsyncAzureOpenAI


# Credentials are read from the environment so that secrets never have to be
# typed into (or committed with) the notebook. Set AZURE_OPENAI_ENDPOINT,
# AZURE_OPENAI_API_KEY and OPENAI_API_VERSION before starting the kernel.
# Unset variables fall back to the empty string, i.e. the previous placeholders.
async_client = AsyncAzureOpenAI(
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", ""),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY", ""),
    api_version=os.environ.get("OPENAI_API_VERSION", ""),
)

In [None]:
async def generate_chat_response(async_client, prompt):
    """Request one deterministic chat completion for ``prompt``.

    Args:
        async_client: Client exposing an awaitable
            ``chat.completions.create(**params)``.
        prompt: List of chat message dicts sent as ``messages``.

    Returns:
        The content of the first choice on success. On any exception the
        error is printed, the coroutine pauses for 10 seconds, and the
        sentinel ``-1`` is returned so the caller can retry the prompt later.
    """
    chat_params = {
        "model": "gpt-4",
        "messages": prompt,
        "max_tokens": max_response_tokens,
        "temperature": 0.0,
    }
    try:
        response = await async_client.chat.completions.create(**chat_params)
    except Exception as e:
        # Best-effort by design: failures are reported and signalled with -1
        # rather than raised, so one bad request does not abort the batch.
        print(f"Error in call_async: {e}")
        print("Sleep for 10s...")
        # Back off WITHOUT blocking the event loop: a synchronous time.sleep()
        # here would freeze every other request running in the same gather.
        await asyncio.sleep(10)
        return -1
    return response.choices[0].message.content

In [None]:
import time


async def process_prompts(prompts):
    """Send all ``prompts`` concurrently and collect their results.

    Args:
        prompts: Iterable of chat message lists.

    Returns:
        List with one result per prompt, in input order, as produced by
        ``generate_chat_response`` using the module-level ``async_client``.
    """
    pending_calls = [generate_chat_response(async_client, prompt) for prompt in prompts]
    return await asyncio.gather(*pending_calls)

In [None]:
def chunk_list(lst, chunk_size):
    """Lazily split ``lst`` into consecutive slices of at most ``chunk_size`` items.

    Args:
        lst: Any sliceable sequence with a length.
        chunk_size: Maximum number of items per slice.

    Yields:
        Slices of ``lst`` in order; the final slice may be shorter.
    """
    starts = range(0, len(lst), chunk_size)
    yield from (lst[start:start + chunk_size] for start in starts)

In [None]:
from tqdm.asyncio import tqdm


async def process_prompts_in_batches(prompts, batch_size, repeat=3):
    """Answer all prompts batch by batch, re-submitting unanswered ones each round.

    Args:
        prompts: Dict mapping a prompt key to its chat message list.
        batch_size: Number of prompts sent concurrently per batch.
        repeat: Number of rounds. Every round submits only the keys that
            still have no stored answer.

    Returns:
        Dict mapping prompt keys to string answers. Keys whose requests never
        produced a string are absent.
    """
    all_responses = {}

    for round_idx in range(repeat):
        print(f"round {round_idx}")
        n_before = len(all_responses)

        # Only keys without a stored answer are (re)submitted in this round.
        pending_keys = [key for key in prompts.keys() if key not in all_responses]
        key_batches = list(chunk_list(pending_keys, batch_size))

        for key_batch in tqdm(key_batches, desc="Processing Batches"):
            batch_prompts = [prompts[key] for key in key_batch]
            batch_responses = await process_prompts(batch_prompts)
            for key, answer in zip(key_batch, batch_responses):
                # Non-string results (e.g. the -1 failure sentinel) are not
                # stored, which leaves the key pending for the next round.
                if type(answer) is str:
                    all_responses[key] = answer

        print(f"get {len(all_responses) - n_before} new responses")

    return all_responses

In [None]:
# Choose an appropriate batch size
batch_size = 10  # Adjust based on your system and API limits

# Assuming we are in an async environment
responses = await process_prompts_in_batches(trimmed_prompts, batch_size)
print(f"Processed {len(responses)} responses")

In [None]:
import json


# Write one JSON line per QA row: the question, the reference answer `a`, the
# model answer `a_hat` and the source tag.
answer_path = os.path.join(model_path, "gpt4/qa_output/answer.jsonl")
# open(..., "w") does not create missing parent directories, so make sure the
# output folder exists before writing.
os.makedirs(os.path.dirname(answer_path), exist_ok=True)
with open(answer_path, "w") as file:
    for _, data in qa.iterrows():
        # Rows without a stored response get an empty prediction.
        a_hat = responses.get((data.source, data.hadm_id), "")
        json_string = json.dumps({"hadm_id": data.hadm_id, "q": data.q, "a": data.a, "a_hat": a_hat, "source": data.source})
        file.write(json_string + '\n')