"""Datasets and image loaders for the RSNA hemorrhage classification data.

Provides loaders for pre-windowed JPG images and raw DICOM files, plus
``torch.utils.data.Dataset`` wrappers that return dictionaries holding a
channel-first float32 image and its label vector.
"""
import os
import random

import cv2
import numpy as np
import pandas as pd
import pydicom
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset

# import jpeg4py as jpeg
from utils import get_windowing, window_image

# Image IDs that are dropped from every dataset built in this module.
IGNORE_IDS = [
    'ID_6431af929',
]

# Window name -> [center, width]. The names double as sub-directory names for
# the pre-windowed JPGs, so the 'subdual' spelling must stay as-is.
windows_range = {
    'brain': [40, 80],
    'bone': [600, 2800],
    'subdual': [75, 215]
}

LABEL_COLS = ["epidural", "intraparenchymal", "intraventricular", "subarachnoid", "subdural", "any"]
LABEL_COLS_WITHOUT_ANY = ["epidural", "intraparenchymal", "intraventricular", "subarachnoid", "subdural"]

# NOTE(review): 'image_orientation_patient_1' is absent from this list while
# indices 0 and 2-5 are present. This may be a typo; it is left untouched here
# because adding a column would change the length of the metadata vector.
meta_data_cols = [
    'image_position_patient_0', 'image_position_patient_1', 'image_position_patient_2',
    'image_orientation_patient_0', 'image_orientation_patient_2', 'image_orientation_patient_3',
    'image_orientation_patient_4', 'image_orientation_patient_5'
]

# Locations of the DICOM metadata tables joined in by RSNADataset.
TRAIN_METADATA_CSV = "/data/df_dicom_metadata_train.csv"
TEST_METADATA_CSV = "/data/df_dicom_metadata_test.csv"


def _id_column(df):
    """Return the name of the column of ``df`` that holds the image IDs.

    Preference order: "Image", then "ID", then "sop_instance_uid".
    """
    for candidate in ("Image", "ID"):
        if candidate in df.columns:
            return candidate
    return "sop_instance_uid"


def _label_columns(with_any):
    """Return the label column names, including "any" only if ``with_any``."""
    return LABEL_COLS if with_any else LABEL_COLS_WITHOUT_ANY


def _transform_to_chw(image, transform):
    """Apply ``transform`` (if truthy) and return a float32 C x H x W array.

    ``transform`` is called as ``transform(image=image)`` and must return a
    mapping with an ``'image'`` entry.
    """
    if transform:
        image = transform(image=image)['image']
    return np.transpose(image, (2, 0, 1)).astype(np.float32)


def load_dicom_image(path):
    """Read a DICOM file and return its windowed channels scaled by 1/255.

    The first channel uses the window returned by ``get_windowing`` for the
    file; one further channel is added per entry of ``windows_range``.

    Args:
        path: Path of the ``.dcm`` file.

    Returns:
        Array with the channels stacked on the last axis.
        Assumes ``window_image`` returns 2-D arrays - TODO confirm.
    """
    # dcmread replaces read_file, which is deprecated and removed in pydicom 3.
    data = pydicom.dcmread(path)
    image = data.pixel_array
    window_center, window_width, intercept, slope = get_windowing(data)

    channels = [window_image(image, window_center, window_width, intercept, slope)]
    channels.extend(
        window_image(image, center, width, intercept, slope)
        for center, width in windows_range.values()
    )

    images = np.asarray(channels).transpose((1, 2, 0))
    return images / 255


def load_image(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        path: Path of the image file.

    Returns:
        The decoded image converted from BGR to RGB.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def load_random_windows(path, id):
    """Load ``<path>/<window>/<id>.jpg`` for a randomly chosen window name.

    Args:
        path: Directory containing one sub-directory per window name.
        id: Image ID (file name without extension).

    Returns:
        The RGB image produced by ``load_image``.
    """
    random_window = np.random.choice(list(windows_range), 1)[0]
    return load_image(os.path.join(path, random_window, id + ".jpg"))


def load_multi_images(root, image_name):
    """Load one grayscale image per window and stack them as channels.

    Args:
        root: Directory containing one sub-directory per window name.
        image_name: File name (with extension) looked up in each sub-directory.

    Returns:
        Array with one channel per entry of ``windows_range`` on the last axis.

    Raises:
        FileNotFoundError: If any of the window images cannot be read.
    """
    channels = []
    for window_name in windows_range:
        channel_path = os.path.join(root, window_name, image_name)
        channel = cv2.imread(channel_path, cv2.IMREAD_GRAYSCALE)
        if channel is None:
            raise FileNotFoundError(f"Could not read image: {channel_path}")
        channels.append(channel)

    return np.asarray(channels).transpose((1, 2, 0))


# def load_jpeg_image(path):
#     image = jpeg.JPEG(path).decode()
#     return image


def get_balance_set(df):
    """Return the rows of ``df`` for a patient-balanced subset.

    Keeps every patient having at least one row with ``any == 1`` and an
    equally sized random sample of the remaining patients (or all of them if
    there are fewer).

    Args:
        df: DataFrame with "patient_id" and "any" columns.

    Returns:
        The rows of ``df`` whose patient belongs to the balanced subset.
    """
    patients = set(df["patient_id"].unique())
    patients_pos = set(df.loc[df["any"] == 1, "patient_id"].unique())
    patients_neg = patients - patients_pos

    # random.sample rejects sets on Python 3.11+, so sample from a list;
    # sorting makes the draw reproducible under a fixed seed. The sample size
    # is capped so that an excess of positive patients cannot raise ValueError.
    sample_size = min(len(patients_pos), len(patients_neg))
    patients_neg_balance = random.sample(sorted(patients_neg), sample_size)
    patients_balance = patients_pos.union(patients_neg_balance)

    print(len(patients), len(patients_pos), len(patients_neg), len(patients_balance))

    return df[df["patient_id"].isin(patients_balance)]


class RSNADataset(Dataset):
    """
    Read JPG images together with their labels and DICOM metadata.
    """
    def __init__(self, csv_file, root, with_any, transform, mode='train', image_type='jpg'):
        """Build the sample index.

        Args:
            csv_file: Path of a CSV file, or an already loaded DataFrame.
            root: Directory containing the image files.
            with_any: Whether the "any" column is part of the label vector.
            transform: Callable invoked as ``transform(image=...)``, or a
                falsy value to skip augmentation.
            mode: 'train' balances the patients and uses the train metadata,
                'valid' uses the train metadata, anything else the test
                metadata.
            image_type: File extension of the images, without the dot.
        """
        if isinstance(csv_file, pd.DataFrame):
            df = csv_file
        else:
            print(csv_file)
            df = pd.read_csv(csv_file)

        if mode == 'train':
            df = get_balance_set(df)

        metadata_csv = TRAIN_METADATA_CSV if mode in ['train', 'valid'] else TEST_METADATA_CSV
        meta_data = pd.read_csv(metadata_csv, usecols=meta_data_cols + ['sop_instance_uid'])

        # Work on a private copy: prefixing in place would alter the caller's
        # DataFrame, and building a second dataset from it would then produce
        # "ID_ID_..." keys that no longer match the metadata table.
        df = df.copy()
        df["sop_instance_uid"] = "ID_" + df["sop_instance_uid"]
        meta_data = meta_data[meta_data['sop_instance_uid'].isin(df['sop_instance_uid'])]
        df = df.merge(meta_data, on='sop_instance_uid', how='left')

        ID_col = _id_column(df)
        df = df[~df[ID_col].isin(IGNORE_IDS)]

        self.ids = df[ID_col].values
        self.metadata = df[meta_data_cols].values
        self.with_any = with_any
        self.labels = df[_label_columns(with_any)].values
        self.root = root
        self.transform = transform
        self.image_type = image_type

    def __len__(self):
        """Return the number of samples."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return the image, label vector and metadata vector of sample ``idx``."""
        image_id = self.ids[idx]
        label = self.labels[idx].astype(np.float32)
        meta = self.metadata[idx].astype(np.float32)

        # File names carry the "ID_" prefix even when the ID column does not.
        if "ID" not in image_id:
            image_id = "ID_" + image_id

        image = load_image(os.path.join(self.root, image_id + "." + self.image_type))
        image = _transform_to_chw(image, self.transform)

        return {
            'images': image,
            'targets': label,
            'meta': meta
        }


class RSNARandomWindowDataset(RSNADataset):
    """
    Randomly select the bone, brain or subdual window during training.
    """

    def __getitem__(self, idx):
        """Return the image (random window) and label vector of sample ``idx``."""
        # NOTE(review): unlike RSNADataset.__getitem__, no "ID_" prefix is
        # added here - confirm that the IDs always carry it.
        image_id = self.ids[idx]
        label = self.labels[idx].astype(np.float32)

        image = load_random_windows(self.root, image_id)
        image = _transform_to_chw(image, self.transform)

        return {
            'images': image,
            'targets': label
        }


class RSNADicomDataset(RSNADataset):
    """
    Load DICOM images directly. Windows are applied on the fly.
    """
    def __init__(self, csv_file, root, with_any, transform, mode='train'):
        """Build the sample index; see ``RSNADataset.__init__`` for the arguments."""
        super().__init__(csv_file, root, with_any, transform, mode)

    def __getitem__(self, idx):
        """Return the windowed DICOM image and label vector of sample ``idx``."""
        # NOTE(review): unlike RSNADataset.__getitem__, no "ID_" prefix is
        # added here - confirm that the IDs always carry it.
        image_id = self.ids[idx]
        label = self.labels[idx].astype(np.float32)

        image = load_dicom_image(os.path.join(self.root, image_id + ".dcm"))
        image = _transform_to_chw(image, self.transform)

        return {
            'images': image,
            'targets': label
        }


class RSNAMultiWindowsDataset(Dataset):
    """
    Read all window images, then concatenate them as channels.
    """
    def __init__(self, csv_file, root, with_any, transform):
        """Build the sample index.

        Args:
            csv_file: Path of a CSV file, or an already loaded DataFrame.
            root: Directory containing one sub-directory per window name.
            with_any: Whether the "any" column is part of the label vector.
            transform: Callable invoked as ``transform(image=...)``, or a
                falsy value to skip augmentation.
        """
        if isinstance(csv_file, pd.DataFrame):
            df = csv_file
        else:
            df = pd.read_csv(csv_file)

        ID_col = _id_column(df)
        df = df[~df[ID_col].isin(IGNORE_IDS)]

        self.ids = df[ID_col].values
        self.with_any = with_any
        self.labels = df[_label_columns(with_any)].values
        self.root = root
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return the multi-window image and label vector of sample ``idx``."""
        image_id = self.ids[idx]
        label = self.labels[idx].astype(np.float32)

        image = load_multi_images(self.root, image_id + ".jpg")
        image = _transform_to_chw(image, self.transform)

        return {
            'images': image,
            'targets': label
        }