|
a |
|
b/src/dataset/dataset.py |
|
|
1 |
import logging |
|
|
2 |
import os |
|
|
3 |
import re |
|
|
4 |
|
|
|
5 |
import pandas as pd |
|
|
6 |
import torch |
|
|
7 |
from torch.utils.data import Dataset |
|
|
8 |
|
|
|
9 |
from src.utils import processed_data_path |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
class InstructionTuningDataset(Dataset): |
|
|
13 |
def __init__(self, split, source): |
|
|
14 |
assert split in ["train", "val", "test", "test_subset"] |
|
|
15 |
assert source in ["event", "note", "joint", "joint_all"] |
|
|
16 |
self.split = split |
|
|
17 |
self.source = source |
|
|
18 |
self.data_path = os.path.join(processed_data_path, f"mimic4") |
|
|
19 |
self.cohort = pd.read_csv(os.path.join(self.data_path, f"cohort_{split}.csv")) |
|
|
20 |
|
|
|
21 |
qa_note = pd.read_json(os.path.join(self.data_path, "qa_note.jsonl"), lines=True) |
|
|
22 |
qa_note = qa_note[qa_note.hadm_id.isin(self.cohort.hadm_id.unique())] |
|
|
23 |
|
|
|
24 |
qa_event = pd.read_json(os.path.join(self.data_path, f"qa_event.jsonl"), lines=True) |
|
|
25 |
qa_event = qa_event[qa_event.hadm_id.isin(self.cohort.hadm_id.unique())] |
|
|
26 |
|
|
|
27 |
if source == "note": |
|
|
28 |
qa = qa_note |
|
|
29 |
elif source == "event": |
|
|
30 |
qa = qa_event |
|
|
31 |
else: |
|
|
32 |
if source == "joint": |
|
|
33 |
logging.warning(f"Subsample event QA to {len(qa_note)}") |
|
|
34 |
qa_event = qa_event.sample(n=len(qa_note), replace=False, random_state=42) |
|
|
35 |
else: |
|
|
36 |
logging.warning(f"Use all event QA") |
|
|
37 |
qa = pd.concat([qa_note, qa_event], ignore_index=True) |
|
|
38 |
|
|
|
39 |
self.qa = qa |
|
|
40 |
logging.warning(f"Loaded {len(qa)} {source} QA samples for {split} on {event}") |
|
|
41 |
|
|
|
42 |
def _get_event_list(self, hadm_id): |
|
|
43 |
df = pd.read_csv(os.path.join(self.data_path, f"event_selected/event_{hadm_id}.csv")) |
|
|
44 |
event_list = [] |
|
|
45 |
for i, row in df.iterrows(): |
|
|
46 |
event_list.append((row.timestamp, row.event_type, row.event_value)) |
|
|
47 |
return event_list |
|
|
48 |
|
|
|
49 |
def _get_event_emb(self, hadm_id): |
|
|
50 |
return torch.load(os.path.join(self.data_path, f"pt_event_selected_no_time_type/event_{hadm_id}.pt")) |
|
|
51 |
|
|
|
52 |
def __len__(self): |
|
|
53 |
return len(self.qa) |
|
|
54 |
|
|
|
55 |
@staticmethod |
|
|
56 |
def _extract_digits(event_tuple): |
|
|
57 |
timestamp, event_type, event_value = event_tuple |
|
|
58 |
try: |
|
|
59 |
if event_type == "patient demographics" or event_type == "patient_demographics": |
|
|
60 |
value_match = re.search(r"age:\s*([\d.]+)", event_value) |
|
|
61 |
if value_match: |
|
|
62 |
value = float(value_match.group(1)) |
|
|
63 |
else: |
|
|
64 |
value = 0 |
|
|
65 |
duration = 0 |
|
|
66 |
elif event_type == "admission info" or event_type == "admission_info": |
|
|
67 |
value, duration = 0, 0 |
|
|
68 |
elif event_type == "diagnoses_icd": |
|
|
69 |
value, duration = 0, 0 |
|
|
70 |
elif event_type == "labevents": |
|
|
71 |
value_match = re.search(r":\s*([\d.]+)", event_value) |
|
|
72 |
if value_match: |
|
|
73 |
value = float(value_match.group(1)) |
|
|
74 |
else: |
|
|
75 |
value = 0 |
|
|
76 |
duration = 0 |
|
|
77 |
elif event_type == "microbiologyevents": |
|
|
78 |
value, duration = 0, 0 |
|
|
79 |
elif event_type == "prescriptions": |
|
|
80 |
value_match = re.search(r"prescribed dose:\s*([\d.]+)", event_value) |
|
|
81 |
if value_match: |
|
|
82 |
value = float(value_match.group(1)) |
|
|
83 |
else: |
|
|
84 |
value = 0 |
|
|
85 |
duration_match = re.search(r"duration:\s*([\d.]+)", event_value) |
|
|
86 |
if duration_match: |
|
|
87 |
duration = float(duration_match.group(1)) |
|
|
88 |
else: |
|
|
89 |
duration = 0 |
|
|
90 |
elif event_type == "transfers": |
|
|
91 |
value, duration = 0, 0 |
|
|
92 |
elif event_type == "procedureevents": |
|
|
93 |
value = 0 |
|
|
94 |
duration_match = re.search(r"for\s*([\d.]+)\s*hour", event_value) |
|
|
95 |
if duration_match: |
|
|
96 |
duration = float(duration_match.group(1)) |
|
|
97 |
else: |
|
|
98 |
duration = 0 |
|
|
99 |
else: |
|
|
100 |
raise ValueError(f"Unknown event type: {event_type}") |
|
|
101 |
except Exception as e: |
|
|
102 |
value, duration = 0, 0 |
|
|
103 |
logging.warning(f"Error {e} in extracting digits from event tuple: {event_tuple}") |
|
|
104 |
return value, duration |
|
|
105 |
|
|
|
106 |
def __getitem__(self, index): |
|
|
107 |
data = self.qa.iloc[index] |
|
|
108 |
q = data["q"] |
|
|
109 |
a = data["a"] |
|
|
110 |
event_emb = self._get_event_emb(data["hadm_id"]) |
|
|
111 |
num_events = event_emb.shape[0] |
|
|
112 |
|
|
|
113 |
event_list = self._get_event_list(data["hadm_id"]) |
|
|
114 |
assert len(event_list) == num_events |
|
|
115 |
time_tensor = torch.tensor([[e[0]] for e in event_list], dtype=torch.float32) |
|
|
116 |
value_duration_tensor = torch.tensor([self._extract_digits(e) for e in event_list], dtype=torch.float32) |
|
|
117 |
event_emb = torch.cat( |
|
|
118 |
[ |
|
|
119 |
event_emb, |
|
|
120 |
time_tensor, |
|
|
121 |
value_duration_tensor, |
|
|
122 |
], |
|
|
123 |
dim=1 |
|
|
124 |
) |
|
|
125 |
final_q = "\n".join(["<image>" * num_events, q]) |
|
|
126 |
|
|
|
127 |
return final_q, a, event_emb |