In [None]:
import os
import sys

# Make the directory two levels above the working directory importable so that
# `from src.… import …` resolves in the cells below.
src_path = os.path.abspath("../..")
print(src_path)
# Guard so re-running this cell does not append duplicate sys.path entries.
if src_path not in sys.path:
    sys.path.append(src_path)

In [None]:
from src.utils import create_directory, raw_data_path, processed_data_path, set_seed, remote_project_path

In [None]:
# Seed via the project helper so sampling-dependent steps are repeatable.
SEED = 42
set_seed(seed=SEED)

In [None]:
import pandas as pd

In [None]:
# Root directory for model outputs; the answer file is written beneath it at the end.
OUTPUT_SUBDIR = "output"
model_path = os.path.join(remote_project_path, OUTPUT_SUBDIR)

In [None]:
# Processed MIMIC-IV directory that holds the cohort CSV read below.
DATASET_NAME = "mimic4"
output_path = os.path.join(processed_data_path, DATASET_NAME)

In [None]:
# Load the test-cohort table and preview it.
cohort_file = os.path.join(output_path, "cohort_test_subset.csv")
cohort = pd.read_csv(cohort_file)
print(cohort.shape)
cohort.head()

In [None]:
# Distinct admission ids in the cohort (not referenced by the cells below).
hadm_ids = set(cohort["hadm_id"].unique().tolist())
len(hadm_ids)

In [None]:
import logging
import os

import pandas as pd
import torch
from torch.utils.data import Dataset
import re

from src.utils import processed_data_path


class EvalInstructionTuningDataset(Dataset):
    """Evaluation dataset pairing each QA item with its admission's event features.

    Each item is ``(final_q, a, event_emb, source, hadm_id)``:
      * ``final_q`` -- the question prefixed by one ``<image>`` placeholder per event;
      * ``event_emb`` -- the stored per-event embedding with three columns appended:
        timestamp, numeric value and duration (see ``_extract_digits``);
      * ``source`` -- ``"note"`` when the QA row has no ``event_type``, else ``"event"``.

    Args:
        data_path: Directory containing ``qa_file``, ``event_selected/`` and
            ``pt_event_selected_no_time_type/``. Defaults to ``<processed_data_path>/mimic4``.
        qa_file: Name of the QA CSV inside ``data_path``.
    """

    def __init__(self, data_path=None, qa_file="qa_test_subset.csv"):
        if data_path is None:
            data_path = os.path.join(processed_data_path, "mimic4")
        self.data_path = data_path
        qa = pd.read_csv(os.path.join(self.data_path, qa_file))
        # QA rows with no event_type are labelled "note"; all others "event".
        qa["source"] = qa.event_type.apply(lambda x: "note" if pd.isna(x) else "event")
        self.qa = qa
        logging.warning(f"Loaded {len(qa)} QA samples")

    def _get_event_list(self, hadm_id):
        """Return the admission's events as ``(timestamp, event_type, event_value)`` tuples in file order."""
        df = pd.read_csv(os.path.join(self.data_path, f"event_selected/event_{hadm_id}.csv"))
        return list(zip(df["timestamp"], df["event_type"], df["event_value"]))

    def _get_event_emb(self, hadm_id):
        """Load the admission's precomputed event-embedding tensor onto the CPU.

        ``map_location="cpu"`` keeps it on the same device as the feature columns
        concatenated in ``__getitem__`` even if the file was saved from a GPU.
        """
        return torch.load(
            os.path.join(self.data_path, f"pt_event_selected_no_time_type/event_{hadm_id}.pt"),
            map_location="cpu",
        )

    def __len__(self):
        return len(self.qa)

    @staticmethod
    def _extract_digits(event_tuple):
        """Extract ``(value, duration)`` numeric features from one event tuple.

        Per event type: demographics -> age; labevents -> first number following a colon;
        prescriptions -> prescribed dose and duration; procedureevents -> hours as duration.
        Other known types yield ``(0, 0)``, and a missing match defaults that field to 0.
        Any failure (unknown event type, non-string ``event_value``, unparsable number such
        as ``"1.2.3"``) is logged as a warning and yields ``(0, 0)`` for both fields.
        """
        _, event_type, event_value = event_tuple

        def first_number(pattern):
            # Float of the first capture group, or 0 when the pattern does not match.
            match = re.search(pattern, event_value)
            return float(match.group(1)) if match else 0

        try:
            if event_type in ("admission info", "diagnoses_icd", "microbiologyevents", "transfers"):
                value, duration = 0, 0
            elif event_type == "patient demographics":
                value, duration = first_number(r"age:\s*([\d.]+)"), 0
            elif event_type == "labevents":
                value, duration = first_number(r":\s*([\d.]+)"), 0
            elif event_type == "prescriptions":
                value = first_number(r"prescribed dose:\s*([\d.]+)")
                duration = first_number(r"duration:\s*([\d.]+)")
            elif event_type == "procedureevents":
                value, duration = 0, first_number(r"for\s*([\d.]+)\s*hour")
            else:
                # Deliberately caught below: unknown types degrade to (0, 0) with a warning.
                raise ValueError(f"Unknown event type: {event_type}")
        except Exception as e:
            value, duration = 0, 0
            logging.warning(f"Error {e} in extracting digits from event tuple: {event_tuple}")
        return value, duration

    def __getitem__(self, index):
        """Build the model inputs for the ``index``-th QA pair.

        Returns:
            ``(final_q, a, event_emb, source, hadm_id)`` where ``event_emb`` has shape
            ``(num_events, emb_dim + 3)``: the stored embedding followed by the timestamp,
            value and duration columns.
        """
        data = self.qa.iloc[index]
        hadm_id = data["hadm_id"]
        event_emb = self._get_event_emb(hadm_id)
        num_events = event_emb.shape[0]
        event_list = self._get_event_list(hadm_id)
        assert len(event_list) == num_events, (
            f"hadm_id {hadm_id}: {len(event_list)} events in CSV but {num_events} embedding rows"
        )
        # reshape keeps both tensors 2-D even for an admission with zero events,
        # so the concatenation along dim=1 below still succeeds.
        time_tensor = torch.tensor(
            [event[0] for event in event_list], dtype=torch.float32
        ).reshape(-1, 1)
        value_duration_tensor = torch.tensor(
            [self._extract_digits(event) for event in event_list], dtype=torch.float32
        ).reshape(-1, 2)
        event_emb = torch.cat([event_emb, time_tensor, value_duration_tensor], dim=1)
        # One <image> placeholder per event, then the question on its own line.
        final_q = "\n".join(["<image>" * num_events, data["q"]])
        return final_q, data["a"], event_emb, data["source"], hadm_id

In [None]:
# Build the evaluation dataset and inspect the first sample.
dataset = EvalInstructionTuningDataset()
first_sample = dataset[0]
q, a, event_emb, source, hadm_id = first_sample
for field in (q, a, source, hadm_id):
    print(field)
print(event_emb.shape)

In [None]:
from src.model.modeling_llemr import LlemrForConditionalGeneration
from src.model.init_llemr import init_llemr
from transformers import AutoTokenizer
from src.model.modeling_dummy import DummyModel
from peft import PeftModel

# Checkpoints: base LLM plus the LLEMR LoRA adapter applied on top of it.
llm_pretrained_model_name_or_path = "lmsys/vicuna-7b-v1.5"
lora_name_or_path = "zzachw12/llemr-v1"
device = "cuda:0"
# NOTE(review): 1027 is presumably the per-event input width (stored embedding dim + the 3
# columns appended in EvalInstructionTuningDataset.__getitem__) — confirm against init_llemr.
INIT_LLEMR_DIM = 1027
model, tokenizer = init_llemr(llm_pretrained_model_name_or_path, INIT_LLEMR_DIM)
# Order matters: cast the base model to bfloat16, attach the adapter, then move it and
# switch to eval mode.
model.to(torch.bfloat16)
model = PeftModel.from_pretrained(model, lora_name_or_path)
model.to(device)
model.eval()
sys_prompt = "You are an AI assistant specialized in analyzing ICU patient data."

In [None]:
# Show the dtype the model reports after the bfloat16 cast and adapter load.
model_dtype = model.dtype
model_dtype

In [None]:
from tqdm import tqdm


# Generate one answer per QA pair.
# Responses are keyed by (source, hadm_id) and looked up by that key when answers are
# exported, so the key must be unique across the QA table; otherwise later generations
# would silently overwrite earlier ones and exported rows would get another question's answer.
# Checked up front so a bad table fails before the expensive generation loop.
duplicate_keys = dataset.qa.duplicated(subset=["source", "hadm_id"], keep=False)
assert not duplicate_keys.any(), (
    f"{int(duplicate_keys.sum())} QA rows share a (source, hadm_id) key; "
    "responses keyed that way would overwrite each other"
)

all_responses = {}
for idx in tqdm(range(len(dataset))):
    q, a, event_emb, source, hadm_id = dataset[idx]
    message = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": q},
    ]
    prompt = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    # NOTE(review): truncation could drop trailing <image> placeholders for very long prompts,
    # leaving fewer placeholders than event rows — confirm prompts stay within the limit.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=False,
    ).to(device)
    # (num_events, feat) -> (num_events, 1, feat); presumably one "image" per event — TODO confirm.
    pixel_values = event_emb.unsqueeze(1).to(device)
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pixel_values=pixel_values,
        max_new_tokens=256,
    )
    # Keep only tokens after the prompt (assumes generate returns prompt + continuation).
    prompt_len = inputs["input_ids"].shape[1]
    all_responses[(source, hadm_id)] = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    )

In [None]:
print("Processed", len(all_responses), "responses")

In [None]:
# Ensure the directory for the exported answers exists.
qa_output_dir = os.path.join(model_path, "llemr_vicuna/qa_output")
create_directory(qa_output_dir)

In [None]:
import json


# Write one JSON object per QA row; rows without a generated response get an empty a_hat.
answer_file = os.path.join(model_path, "llemr_vicuna/qa_output/answer.jsonl")
with open(answer_file, "w") as fh:
    for _, row in dataset.qa.iterrows():
        record = {
            "hadm_id": row.hadm_id,
            "q": row.q,
            "a": row.a,
            "a_hat": all_responses.get((row.source, row.hadm_id), ""),
            "source": row.source,
        }
        print(json.dumps(record), file=fh)