Diff of /datasets.py [000000] .. [77dc1e]

Switch to unified view

a b/datasets.py
1
import numpy as np
2
import pandas as pd
3
from PIL import Image, ImageFile
4
import os
5
import torch
6
from torch.utils.data import Dataset, Sampler
7
import sys
8
sys.path.append("/home/anjum/PycharmProjects/kaggle")
9
# sys.path.append("/home/anjum/rsna_code")  # GCP
10
from rsna_intracranial_hemorrhage_detection.data_prep import linear_windowing, sigmoid_windowing
11
# Tell PIL to load image files whose data is truncated rather than raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
12
13
14
# Root directory of the competition data: the label/submission CSV files and the
# "png"/"npy" image trees used by ICHDataset are all resolved relative to this path.
data_path = "/mnt/storage_dimm2/kaggle_data/rsna-intracranial-hemorrhage-detection/"
15
# data_path = "/home/anjum/rsna_data/"  # GCP
16
17
18
class ICHDataset(Dataset):
    """
    Dataset of RSNA intracranial-hemorrhage images with one label row per image.

    Labels are read from the stage CSV files below ``data_path`` and pivoted so that every ImageID
    has one column per diagnosis. Images are read from the matching "png" or "npy" directory and
    passed through ``linear_windowing`` before any transforms are applied.
    """

    def __init__(self, dataset, phase=1, image_filter=None, transforms=None, image_folder=None, png=True,
                 window_width=200, window_length=80):
        """
        :param dataset: One of "train", "test1" or "test2"; selects both the CSV file and the image directory
        :param phase: 0 keeps every diagnosis column, 1 keeps only the "any" column, any other value keeps
                      the five subtype columns (epidural, intraparenchymal, intraventricular, subarachnoid,
                      subdural)
        :param image_filter: Optional ImageID labels handed to ``DataFrame.loc`` to keep a subset of images
        :param transforms: Optional callable applied to the windowed image in ``__getitem__``
        :param image_folder: Sub-folder below the per-dataset image directory. None (the default) means the
                             images live directly in that directory
        :param png: When True images are read from ".png" files with PIL, otherwise from ".npy" arrays
        :param window_width: First windowing parameter forwarded to ``linear_windowing``
        :param window_length: Second windowing parameter forwarded to ``linear_windowing``. The defaults
                              (200, 80) are the subdural window; other presets noted by the original author
                              are (80, 40) for the brain window and (130, 50) as an alternative subdural window
        """
        df_paths = {
            "train": os.path.join(data_path, "stage_1_train.csv"),
            "test1": os.path.join(data_path, "stage_1_sample_submission.csv"),
            "test2": os.path.join(data_path, "stage_2_sample_submission.csv"),
        }

        self.png = png
        self.window_width = window_width
        self.window_length = window_length

        # os.path.join raises TypeError on None, so the default image_folder maps to an empty component
        subfolder = "" if image_folder is None else image_folder
        # The png and npy trees share the same layout and differ only in this path component
        image_format = "png" if self.png else "npy"
        image_dirs = {
            "train": os.path.join(data_path, image_format, "train", subfolder),
            "test1": os.path.join(data_path, image_format, "test_stage_1", subfolder),
            "test2": os.path.join(data_path, image_format, "test_stage_2", subfolder),
        }

        self.dataset = dataset
        self.phase = phase
        self.transforms = transforms
        self.image_dir = image_dirs[dataset]

        # Each CSV ID is "<12-character ImageID>_<Diagnosis>"; split it so the table can be pivoted
        self.df = pd.read_csv(df_paths[dataset]).drop_duplicates()
        self.df['ImageID'] = self.df['ID'].str.slice(stop=12)
        self.df['Diagnosis'] = self.df['ID'].str.slice(start=13)
        self.df_pivot = self.df.pivot(index="ImageID", columns="Diagnosis", values="Label")

        if image_filter is not None:
            self.df_pivot = self.df_pivot.loc[image_filter]

        if self.phase == 0:
            self.labels = self.df_pivot.values
        elif self.phase == 1:
            self.labels = self.df_pivot["any"].values.reshape(-1, 1)
        else:
            self.labels = self.df_pivot[["epidural", "intraparenchymal",
                                         "intraventricular", "subarachnoid", "subdural"]].values

        self.image_ids = self.df_pivot.index.values
        # Column-wise mean of the selected labels
        self.class_weights = np.mean(self.labels, axis=0)

    def load_image(self, image_name):
        """
        Loads one image from ``self.image_dir`` and applies ``linear_windowing`` to it
        :param image_name: Image ID, i.e. the file name without its extension
        :return: The output of ``linear_windowing``; for a ".npy" array whose first or second dimension is
                 empty, an all-zero uint8 array of shape (512, 512, 3) instead
        """
        if self.png:
            png_path = os.path.join(self.image_dir, image_name + ".png")
            # Image.open keeps the file open until the image is closed; the context manager releases the
            # handle as soon as the pixels have been copied into the array
            with Image.open(png_path) as pil_img:
                img = np.array(pil_img.convert("RGB"))
            return linear_windowing(img, self.window_width, self.window_length)

        img = np.load(os.path.join(self.image_dir, image_name + ".npy"))

        # PIL doesn't work with 16-bit RGB images :(
        # Should be ok though since the useful HU interval is between 0-255
        if img.shape[0] == 0 or img.shape[1] == 0:
            # Empty array on disk: substitute a blank image so the batch can still be assembled
            return np.zeros(shape=(512, 512, 3), dtype=np.uint8)
        return linear_windowing(img, self.window_width, self.window_length)

    def __getitem__(self, idx):
        """
        :param idx: Integer position into ``self.image_ids``
        :return: Tuple of (image, float32 label tensor). Only the "train" dataset returns its stored labels;
                 every other dataset returns a single-element zero tensor as a placeholder
        """
        img_id = self.image_ids[idx]
        img = self.load_image(img_id)

        if self.transforms is not None:
            img = self.transforms(img)

        if self.dataset == "train":
            target = torch.tensor(self.labels[idx], dtype=torch.float32)
        else:
            target = torch.tensor([0], dtype=torch.float32)
        return img, target

    def __len__(self):
        """
        :return: Number of images in the (optionally filtered) dataset
        """
        return len(self.image_ids)
100
101
102
class BalancedRandomSampler(Sampler):
    def __init__(self, data_source):
        """
        Balances the negative and positive samples. All of the positive samples are used, and a random subset of
        the negative samples is drawn to create a 50:50 dataset. Positives and negatives are decided by the
        first label column only (value 1 or 0).

        When there are fewer negatives than positives the negatives are drawn with replacement so the 50:50
        ratio is kept; when there are no negatives at all only the positives are yielded.
        :param data_source: An ICHDataset
        """
        # The base-class constructor does not use data_source, and its signature is believed to differ
        # between PyTorch releases (argument required in older ones, deprecated in newer ones - confirm
        # against the installed version). Try the argument-free form first and fall back to the old one.
        try:
            super().__init__()
        except TypeError:
            super().__init__(data_source)
        self.labels = data_source.labels
        self.ids_pos = np.where(self.labels[:, 0] == 1)[0]
        self.ids_neg = np.where(self.labels[:, 0] == 0)[0]

    def _num_neg_samples(self):
        # Number of negatives drawn per epoch: one per positive, or none if no negatives exist
        return self.ids_pos.shape[0] if self.ids_neg.shape[0] > 0 else 0

    def __iter__(self):
        """
        :return: Iterator over a freshly shuffled array of all positive indices plus the sampled negatives
        """
        n_neg_samples = self._num_neg_samples()
        if n_neg_samples == 0:
            ids_neg_sampled = self.ids_neg[:0]
        else:
            # Sampling without replacement is impossible when more samples than negatives are requested
            with_replacement = n_neg_samples > self.ids_neg.shape[0]
            ids_neg_sampled = np.random.choice(self.ids_neg, n_neg_samples, replace=with_replacement)
        ids = np.concatenate([self.ids_pos, ids_neg_sampled])
        np.random.shuffle(ids)
        return iter(ids)

    def __len__(self):
        """
        :return: Number of indices produced by one pass of ``__iter__``
        """
        return self.ids_pos.shape[0] + self._num_neg_samples()