--- a +++ b/src/dataset/dataset.py @@ -0,0 +1,127 @@ +import logging +import os +import re + +import pandas as pd +import torch +from torch.utils.data import Dataset + +from src.utils import processed_data_path + + +class InstructionTuningDataset(Dataset): + def __init__(self, split, source): + assert split in ["train", "val", "test", "test_subset"] + assert source in ["event", "note", "joint", "joint_all"] + self.split = split + self.source = source + self.data_path = os.path.join(processed_data_path, f"mimic4") + self.cohort = pd.read_csv(os.path.join(self.data_path, f"cohort_{split}.csv")) + + qa_note = pd.read_json(os.path.join(self.data_path, "qa_note.jsonl"), lines=True) + qa_note = qa_note[qa_note.hadm_id.isin(self.cohort.hadm_id.unique())] + + qa_event = pd.read_json(os.path.join(self.data_path, f"qa_event.jsonl"), lines=True) + qa_event = qa_event[qa_event.hadm_id.isin(self.cohort.hadm_id.unique())] + + if source == "note": + qa = qa_note + elif source == "event": + qa = qa_event + else: + if source == "joint": + logging.warning(f"Subsample event QA to {len(qa_note)}") + qa_event = qa_event.sample(n=len(qa_note), replace=False, random_state=42) + else: + logging.warning(f"Use all event QA") + qa = pd.concat([qa_note, qa_event], ignore_index=True) + + self.qa = qa + logging.warning(f"Loaded {len(qa)} {source} QA samples for {split} on {event}") + + def _get_event_list(self, hadm_id): + df = pd.read_csv(os.path.join(self.data_path, f"event_selected/event_{hadm_id}.csv")) + event_list = [] + for i, row in df.iterrows(): + event_list.append((row.timestamp, row.event_type, row.event_value)) + return event_list + + def _get_event_emb(self, hadm_id): + return torch.load(os.path.join(self.data_path, f"pt_event_selected_no_time_type/event_{hadm_id}.pt")) + + def __len__(self): + return len(self.qa) + + @staticmethod + def _extract_digits(event_tuple): + timestamp, event_type, event_value = event_tuple + try: + if event_type == "patient demographics" or event_type == "patient_demographics": + value_match = re.search(r"age:\s*([\d.]+)", event_value) + if value_match: + value = float(value_match.group(1)) + else: + value = 0 + duration = 0 + elif event_type == "admission info" or event_type == "admission_info": + value, duration = 0, 0 + elif event_type == "diagnoses_icd": + value, duration = 0, 0 + elif event_type == "labevents": + value_match = re.search(r":\s*([\d.]+)", event_value) + if value_match: + value = float(value_match.group(1)) + else: + value = 0 + duration = 0 + elif event_type == "microbiologyevents": + value, duration = 0, 0 + elif event_type == "prescriptions": + value_match = re.search(r"prescribed dose:\s*([\d.]+)", event_value) + if value_match: + value = float(value_match.group(1)) + else: + value = 0 + duration_match = re.search(r"duration:\s*([\d.]+)", event_value) + if duration_match: + duration = float(duration_match.group(1)) + else: + duration = 0 + elif event_type == "transfers": + value, duration = 0, 0 + elif event_type == "procedureevents": + value = 0 + duration_match = re.search(r"for\s*([\d.]+)\s*hour", event_value) + if duration_match: + duration = float(duration_match.group(1)) + else: + duration = 0 + else: + raise ValueError(f"Unknown event type: {event_type}") + except Exception as e: + value, duration = 0, 0 + logging.warning(f"Error {e} in extracting digits from event tuple: {event_tuple}") + return value, duration + + def __getitem__(self, index): + data = self.qa.iloc[index] + q = data["q"] + a = data["a"] + event_emb = self._get_event_emb(data["hadm_id"]) + num_events = event_emb.shape[0] + + event_list = self._get_event_list(data["hadm_id"]) + assert len(event_list) == num_events + time_tensor = torch.tensor([[e[0]] for e in event_list], dtype=torch.float32) + value_duration_tensor = torch.tensor([self._extract_digits(e) for e in event_list], dtype=torch.float32) + event_emb = torch.cat( + [ + event_emb, + time_tensor, + value_duration_tensor, + ], + dim=1 + ) + final_q = "\n".join(["<image>" * num_events, q]) + + return final_q, a, event_emb