Diff of /src/dataset/dataset.py [000000] .. [780764]

Switch to unified view

a b/src/dataset/dataset.py
1
import logging
2
import os
3
import re
4
5
import pandas as pd
6
import torch
7
from torch.utils.data import Dataset
8
9
from src.utils import processed_data_path
10
11
12
class InstructionTuningDataset(Dataset):
    """Question/answer samples paired with per-admission event embeddings.

    All files are read from ``<processed_data_path>/mimic4``:

    * ``cohort_<split>.csv`` -- admissions (``hadm_id``) belonging to the split.
    * ``qa_note.jsonl`` / ``qa_event.jsonl`` -- QA rows with at least the
      columns ``hadm_id``, ``q`` and ``a``.
    * ``event_selected/event_<hadm_id>.csv`` -- events of one admission with
      columns ``timestamp``, ``event_type`` and ``event_value``.
    * ``pt_event_selected_no_time_type/event_<hadm_id>.pt`` -- a tensor with
      one row per event of that admission.

    Each item is ``(question, answer, event_embedding)``; see ``__getitem__``.
    """

    # Regex (one capture group) that pulls the numeric *value* out of
    # ``event_value`` for the event types that have one.
    _VALUE_PATTERNS = {
        "patient demographics": r"age:\s*([\d.]+)",
        "patient_demographics": r"age:\s*([\d.]+)",
        "labevents": r":\s*([\d.]+)",
        "prescriptions": r"prescribed dose:\s*([\d.]+)",
    }

    # Regex (one capture group) that pulls the numeric *duration* out of
    # ``event_value`` for the event types that have one.
    _DURATION_PATTERNS = {
        "prescriptions": r"duration:\s*([\d.]+)",
        "procedureevents": r"for\s*([\d.]+)\s*hour",
    }

    # Every event type that is accepted; types listed here but absent from the
    # two pattern tables above simply yield ``(0, 0)``.
    _KNOWN_EVENT_TYPES = frozenset(
        {
            "patient demographics",
            "patient_demographics",
            "admission info",
            "admission_info",
            "diagnoses_icd",
            "labevents",
            "microbiologyevents",
            "prescriptions",
            "transfers",
            "procedureevents",
        }
    )

    def __init__(self, split, source):
        """Load the QA rows for ``split`` restricted to the given ``source``.

        Args:
            split: One of ``"train"``, ``"val"``, ``"test"``, ``"test_subset"``.
            source: ``"note"`` or ``"event"`` to use a single QA file,
                ``"joint"`` to use note QA plus an equally sized random subset
                of event QA, or ``"joint_all"`` to use note QA plus all event QA.

        Raises:
            AssertionError: If ``split`` or ``source`` is not a supported value.
        """
        assert split in ["train", "val", "test", "test_subset"]
        assert source in ["event", "note", "joint", "joint_all"]
        self.split = split
        self.source = source
        self.data_path = os.path.join(processed_data_path, "mimic4")
        self.cohort = pd.read_csv(os.path.join(self.data_path, f"cohort_{split}.csv"))

        # Only read the QA file(s) the requested source actually needs.
        if source == "note":
            qa = self._load_qa("qa_note.jsonl")
        elif source == "event":
            qa = self._load_qa("qa_event.jsonl")
        else:
            qa_note = self._load_qa("qa_note.jsonl")
            qa_event = self._load_qa("qa_event.jsonl")
            if source == "joint":
                # Sampling without replacement cannot draw more rows than
                # exist, so clamp to the number of available event QA rows.
                n_keep = min(len(qa_note), len(qa_event))
                logging.warning(f"Subsample event QA to {n_keep}")
                qa_event = qa_event.sample(n=n_keep, replace=False, random_state=42)
            else:
                logging.warning(f"Use all event QA")
            qa = pd.concat([qa_note, qa_event], ignore_index=True)

        self.qa = qa
        logging.warning(f"Loaded {len(qa)} {source} QA samples for {split}")

    def _load_qa(self, filename):
        """Read a JSON-lines QA file and keep only rows whose admission is in the cohort."""
        qa = pd.read_json(os.path.join(self.data_path, filename), lines=True)
        return qa[qa.hadm_id.isin(self.cohort.hadm_id.unique())]

    def _get_event_list(self, hadm_id):
        """Return the admission's events as ``(timestamp, event_type, event_value)`` tuples, in file order."""
        df = pd.read_csv(os.path.join(self.data_path, f"event_selected/event_{hadm_id}.csv"))
        columns = ["timestamp", "event_type", "event_value"]
        return list(df[columns].itertuples(index=False, name=None))

    def _get_event_emb(self, hadm_id):
        """Load the saved event embedding tensor of one admission.

        NOTE: ``torch.load`` unpickles the file, so the ``.pt`` files must come
        from a trusted preprocessing run.
        """
        return torch.load(os.path.join(self.data_path, f"pt_event_selected_no_time_type/event_{hadm_id}.pt"))

    def __len__(self):
        """Return the number of QA samples."""
        return len(self.qa)

    @staticmethod
    def _search_number(pattern, text):
        """Return the first capture group of ``pattern`` in ``text`` as a float, or ``0`` when there is no match."""
        match = re.search(pattern, text)
        return float(match.group(1)) if match else 0

    @staticmethod
    def _extract_digits(event_tuple):
        """Extract the numeric ``(value, duration)`` pair from one event.

        Args:
            event_tuple: ``(timestamp, event_type, event_value)``.

        Returns:
            ``(value, duration)``. A component is ``0`` when the event type does
            not define it or its pattern does not match. ``(0, 0)`` is returned
            (and a warning is logged) when the event type is unknown, when
            ``event_value`` is not a string (e.g. NaN from an empty CSV cell),
            or when the matched text is not a valid number.
        """
        cls = InstructionTuningDataset
        timestamp, event_type, event_value = event_tuple
        try:
            if event_type not in cls._KNOWN_EVENT_TYPES:
                raise ValueError(f"Unknown event type: {event_type}")
            value_pattern = cls._VALUE_PATTERNS.get(event_type)
            duration_pattern = cls._DURATION_PATTERNS.get(event_type)
            value = cls._search_number(value_pattern, event_value) if value_pattern else 0
            duration = cls._search_number(duration_pattern, event_value) if duration_pattern else 0
        except (TypeError, ValueError) as e:
            # Best effort: one malformed event must not abort loading a sample.
            value, duration = 0, 0
            logging.warning(f"Error {e} in extracting digits from event tuple: {event_tuple}")
        return value, duration

    def __getitem__(self, index):
        """Return ``(final_q, a, event_emb)`` for the QA row at position ``index``.

        ``final_q`` is the question prefixed by one ``<image>`` placeholder per
        event and a newline. ``event_emb`` is the stored embedding with three
        extra float columns appended per event: timestamp, value and duration.

        Raises:
            AssertionError: If the event CSV and the embedding tensor disagree
                on the number of events.
        """
        data = self.qa.iloc[index]
        q = data["q"]
        a = data["a"]
        event_emb = self._get_event_emb(data["hadm_id"])
        num_events = event_emb.shape[0]

        event_list = self._get_event_list(data["hadm_id"])
        assert len(event_list) == num_events

        # Explicit reshapes keep both tensors 2-D even when there are no
        # events, so the concatenation along dim=1 stays valid.
        time_tensor = torch.tensor(
            [[e[0]] for e in event_list], dtype=torch.float32
        ).reshape(num_events, 1)
        value_duration_tensor = torch.tensor(
            [self._extract_digits(e) for e in event_list], dtype=torch.float32
        ).reshape(num_events, 2)
        event_emb = torch.cat(
            [
                event_emb,
                time_tensor,
                value_duration_tensor,
            ],
            dim=1,
        )
        final_q = "\n".join(["<image>" * num_events, q])

        return final_q, a, event_emb