Diff of /src/dataset/dataset.py [000000] .. [780764]

Switch to side-by-side view

--- a
+++ b/src/dataset/dataset.py
@@ -0,0 +1,127 @@
+import logging
+import os
+import re
+
+import pandas as pd
+import torch
+from torch.utils.data import Dataset
+
+from src.utils import processed_data_path
+
+
+class InstructionTuningDataset(Dataset):
+    def __init__(self, split, source):
+        assert split in ["train", "val", "test", "test_subset"]
+        assert source in ["event", "note", "joint", "joint_all"]
+        self.split = split
+        self.source = source
+        self.data_path = os.path.join(processed_data_path, f"mimic4")
+        self.cohort = pd.read_csv(os.path.join(self.data_path, f"cohort_{split}.csv"))
+
+        qa_note = pd.read_json(os.path.join(self.data_path, "qa_note.jsonl"), lines=True)
+        qa_note = qa_note[qa_note.hadm_id.isin(self.cohort.hadm_id.unique())]
+
+        qa_event = pd.read_json(os.path.join(self.data_path, f"qa_event.jsonl"), lines=True)
+        qa_event = qa_event[qa_event.hadm_id.isin(self.cohort.hadm_id.unique())]
+
+        if source == "note":
+            qa = qa_note
+        elif source == "event":
+            qa = qa_event
+        else:
+            if source == "joint":
+                logging.warning(f"Subsample event QA to {len(qa_note)}")
+                qa_event = qa_event.sample(n=len(qa_note), replace=False, random_state=42)
+            else:
+                logging.warning(f"Use all event QA")
+            qa = pd.concat([qa_note, qa_event], ignore_index=True)
+
+        self.qa = qa
+        logging.warning(f"Loaded {len(qa)} {source} QA samples for {split} on {event}")
+
+    def _get_event_list(self, hadm_id):
+        df = pd.read_csv(os.path.join(self.data_path, f"event_selected/event_{hadm_id}.csv"))
+        event_list = []
+        for i, row in df.iterrows():
+            event_list.append((row.timestamp, row.event_type, row.event_value))
+        return event_list
+
+    def _get_event_emb(self, hadm_id):
+        return torch.load(os.path.join(self.data_path, f"pt_event_selected_no_time_type/event_{hadm_id}.pt"))
+
+    def __len__(self):
+        return len(self.qa)
+
+    @staticmethod
+    def _extract_digits(event_tuple):
+        timestamp, event_type, event_value = event_tuple
+        try:
+            if event_type == "patient demographics" or event_type == "patient_demographics":
+                value_match = re.search(r"age:\s*([\d.]+)", event_value)
+                if value_match:
+                    value = float(value_match.group(1))
+                else:
+                    value = 0
+                duration = 0
+            elif event_type == "admission info" or event_type == "admission_info":
+                value, duration = 0, 0
+            elif event_type == "diagnoses_icd":
+                value, duration = 0, 0
+            elif event_type == "labevents":
+                value_match = re.search(r":\s*([\d.]+)", event_value)
+                if value_match:
+                    value = float(value_match.group(1))
+                else:
+                    value = 0
+                duration = 0
+            elif event_type == "microbiologyevents":
+                value, duration = 0, 0
+            elif event_type == "prescriptions":
+                value_match = re.search(r"prescribed dose:\s*([\d.]+)", event_value)
+                if value_match:
+                    value = float(value_match.group(1))
+                else:
+                    value = 0
+                duration_match = re.search(r"duration:\s*([\d.]+)", event_value)
+                if duration_match:
+                    duration = float(duration_match.group(1))
+                else:
+                    duration = 0
+            elif event_type == "transfers":
+                value, duration = 0, 0
+            elif event_type == "procedureevents":
+                value = 0
+                duration_match = re.search(r"for\s*([\d.]+)\s*hour", event_value)
+                if duration_match:
+                    duration = float(duration_match.group(1))
+                else:
+                    duration = 0
+            else:
+                raise ValueError(f"Unknown event type: {event_type}")
+        except Exception as e:
+            value, duration = 0, 0
+            logging.warning(f"Error {e} in extracting digits from event tuple: {event_tuple}")
+        return value, duration
+
+    def __getitem__(self, index):
+        data = self.qa.iloc[index]
+        q = data["q"]
+        a = data["a"]
+        event_emb = self._get_event_emb(data["hadm_id"])
+        num_events = event_emb.shape[0]
+
+        event_list = self._get_event_list(data["hadm_id"])
+        assert len(event_list) == num_events
+        time_tensor = torch.tensor([[e[0]] for e in event_list], dtype=torch.float32)
+        value_duration_tensor = torch.tensor([self._extract_digits(e) for e in event_list], dtype=torch.float32)
+        event_emb = torch.cat(
+            [
+                event_emb,
+                time_tensor,
+                value_duration_tensor,
+            ],
+            dim=1
+        )
+        final_q = "\n".join(["<image>" * num_events, q])
+
+        return final_q, a, event_emb