Diff of /src/dataset.py [000000] .. [95f789]

Switch to unified view

a b/src/dataset.py
1
import numpy as np
2
import os
3
import cv2
4
import pandas as pd
5
from torch.utils.data import Dataset
6
# import jpeg4py as jpeg
7
from utils import get_windowing, window_image
8
import pydicom
9
10
IGNORE_IDS = [
11
    'ID_6431af929',
12
]
13
14
windows_range = {
15
    'brain': [40, 80],
16
    'bone': [600, 2800],
17
    'subdual': [75, 215]
18
}
19
20
LABEL_COLS = ["epidural", "intraparenchymal", "intraventricular", "subarachnoid", "subdural", "any"]
21
LABEL_COLS_WITHOUT_ANY = ["epidural", "intraparenchymal", "intraventricular", "subarachnoid", "subdural"]
22
23
24
def load_dicom_image(path):
25
    data = pydicom.read_file(path)
26
    image = data.pixel_array
27
    window_center, window_width, intercept, slope = get_windowing(data)
28
    images = []
29
    image_windowed = window_image(image, window_center, window_width, intercept, slope)
30
    images.append(image_windowed)
31
32
    for k, v in windows_range.items():
33
        image_windowed = window_image(image, v[0], v[1], intercept, slope)
34
        images.append(image_windowed)
35
36
    images = np.asarray(images).transpose((1, 2, 0))
37
    images = images / 255
38
    return images
39
40
41
def load_image(path):
42
    image = cv2.imread(path)
43
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
44
    return image
45
46
47
def load_random_windows(path, id):
48
    random_window = np.random.choice(['brain', 'bone', 'subdual'], 1)[0]
49
    return load_image(os.path.join(path, random_window, id + ".jpg"))
50
51
52
def load_multi_images(root, image_name):
53
    images = []
54
    for i, (k, v) in enumerate(windows_range.items()):
55
        image = cv2.imread(os.path.join(root, k, image_name), 0)
56
        images.append(image)
57
58
    images = np.asarray(images).transpose((1, 2, 0))
59
60
    return images
61
62
63
# def load_jpeg_image(path):
64
#     image = jpeg.JPEG(path).decode()
65
#     return image
66
67
68
import random
69
def get_balance_set(df):
70
    patients = set(df["patient_id"].unique())
71
    patients_pos = set(df[df["any"] == 1]["patient_id"].unique())
72
    patients_neg = patients - patients_pos
73
    patients_neg_balance = random.sample(patients_neg, len(patients_pos))
74
    patients_balance = patients_pos.union(patients_neg_balance)
75
76
    print(len(patients), len(patients_pos), len(patients), len(patients_balance))
77
78
    return df[df["patient_id"].isin(patients_balance)]
79
80
81
from sklearn.preprocessing import MinMaxScaler
82
meta_data_cols = [
83
    'image_position_patient_0', 'image_position_patient_1', 'image_position_patient_2',
84
    'image_orientation_patient_0', 'image_orientation_patient_2', 'image_orientation_patient_3',
85
    'image_orientation_patient_4', 'image_orientation_patient_5'
86
]
87
88
89
class RSNADataset(Dataset):
90
    """
91
    Read JPG images
92
    """
93
    def __init__(self, csv_file, root, with_any, transform, mode='train', image_type='jpg'):
94
        if isinstance(csv_file, pd.DataFrame):
95
            df = csv_file
96
        else:
97
            print(csv_file)
98
            df = pd.read_csv(csv_file)
99
        if mode == 'train':
100
            # df = df
101
            df = get_balance_set(df)
102
        if mode in ['train', 'valid']:
103
            meta_data = pd.read_csv(f"/data/df_dicom_metadata_train.csv", usecols=meta_data_cols + ['sop_instance_uid'])
104
        else:
105
            meta_data = pd.read_csv(f"/data/df_dicom_metadata_test.csv", usecols=meta_data_cols + ['sop_instance_uid'])
106
            df["sop_instance_uid"] = "ID_" + df["sop_instance_uid"]
107
        meta_data = meta_data[meta_data['sop_instance_uid'].isin(df['sop_instance_uid'])]
108
        df = df.merge(meta_data, on='sop_instance_uid', how='left')
109
        ID_col = "Image" if "Image" in df.columns else "ID" if "ID" in df.columns else "sop_instance_uid"
110
        df = df[~df[ID_col].isin(IGNORE_IDS)]
111
        self.ids = df[ID_col].values
112
        self.metadata = df[meta_data_cols].values
113
        self.with_any = with_any
114
        if with_any:
115
            self.labels = df[LABEL_COLS].values
116
        else:
117
            self.labels = df[LABEL_COLS_WITHOUT_ANY].values
118
        self.root = root
119
        self.transform = transform
120
        self.image_type = image_type
121
122
    def __len__(self):
123
        return len(self.ids)
124
125
    def __getitem__(self, idx):
126
        id = self.ids[idx]
127
        label = self.labels[idx].astype(np.float32)
128
129
        meta = self.metadata[idx].astype(np.float32)
130
131
        if not "ID" in id:
132
            id = "ID_" + id
133
134
        image = os.path.join(self.root, id + "." + self.image_type)
135
        image = load_image(image)
136
137
        if self.transform:
138
            augmented = self.transform(image=image)
139
            image = augmented['image']
140
141
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
142
143
        return {
144
            'images': image,
145
            'targets': label,
146
            'meta': meta
147
        }
148
149
150
class RSNARandomWindowDataset(RSNADataset):
151
    """
152
    Random select bone, brain and subdual during the training
153
    """
154
155
    def __getitem__(self, idx):
156
        id = self.ids[idx]
157
        label = self.labels[idx].astype(np.float32)
158
159
        image = load_random_windows(self.root, id)
160
161
        if self.transform:
162
            augmented = self.transform(image=image)
163
            image = augmented['image']
164
165
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
166
167
        return {
168
            'images': image,
169
            'targets': label
170
        }
171
172
173
class RSNADicomDataset(RSNADataset):
174
    """
175
    load dicom image directly. windows are applied on the fly.
176
    """
177
    def __init__(self, csv_file, root, with_any, transform, mode='train'):
178
        super(RSNADicomDataset, self).__init__(csv_file, root, with_any, transform, mode)
179
180
    def __len__(self):
181
        return len(self.ids)
182
183
    def __getitem__(self, idx):
184
        id = self.ids[idx]
185
        label = self.labels[idx].astype(np.float32)
186
187
        image = os.path.join(self.root, id + ".dcm")
188
        image = load_dicom_image(image)
189
190
        if self.transform:
191
            augmented = self.transform(image=image)
192
            image = augmented['image']
193
194
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
195
196
        return {
197
            'images': image,
198
            'targets': label
199
        }
200
201
202
class RSNAMultiWindowsDataset(Dataset):
203
    """
204
    Read all window images then concatinate.
205
    """
206
    def __init__(self, csv_file, root, with_any, transform):
207
        if isinstance(csv_file, pd.DataFrame):
208
            df = csv_file
209
        else:
210
            df = pd.read_csv(csv_file)
211
        ID_col = "Image" if "Image" in df.columns else "ID" if "ID" in df.columns else "sop_instance_uid"
212
        df = df[~df[ID_col].isin(IGNORE_IDS)]
213
        self.ids = df[ID_col].values
214
        self.with_any = with_any
215
        if with_any:
216
            self.labels = df[LABEL_COLS].values
217
        else:
218
            self.labels = df[LABEL_COLS_WITHOUT_ANY].values
219
        self.root = root
220
        self.transform = transform
221
222
    def __len__(self):
223
        return len(self.ids)
224
225
    def __getitem__(self, idx):
226
        id = self.ids[idx]
227
        label = self.labels[idx].astype(np.float32)
228
229
        # image = os.path.join(self.root, id + ".jpg")
230
        image = load_multi_images(self.root, id + ".jpg")
231
232
        if self.transform:
233
            augmented = self.transform(image=image)
234
            image = augmented['image']
235
236
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
237
238
        return {
239
            'images': image,
240
            'targets': label
241
        }