Diff of /dataset.py [000000] .. [1928b6]

Switch to unified view

a b/dataset.py
1
import numpy as np
2
import pandas as pd
3
import os
4
import glob
5
import torch
6
import torch.nn as nn
7
import random
8
import itertools
9
from sklearn.metrics import roc_auc_score
10
import torch.nn.functional as F
11
from scipy import interpolate
12
13
# Directory that holds the raw binary image files referenced by the CSV.
data_path_img = '<path-to-binary-image-files>'
# CSV table describing the images (read in get_datasets_singleview).
data_file = '<path-and_name-of-csv-datafile>'
# Side length of the square network input: 310x310 is the size of the MIP
# images used in the paper. Adjust to fit your images.
image_size = 310
16
17
def get_datasets_singleview(transform=None, norm=None, balance=False, split_index=0):
    """Build the train, train-eval, validation and test datasets of one split.

    The table at ``data_file`` is read and the column named
    ``'split' + str(split_index)`` decides which rows belong to the
    ``'train'``, ``'val'`` and ``'test'`` parts. All columns whose name
    matches ``split`` are dropped from the rows handed to the datasets.

    Args:
        transform: Augmentation given to the training dataset only; the three
            evaluation datasets are always built with ``transform=None``.
        norm: Forwarded unchanged to every dataset.
        balance: Accepted for interface compatibility; not used here.
        split_index: Number of the split column to use.

    Returns:
        Tuple ``(train_dset, trainval_dset, val_dset, test_dset,
        weight_neg_pos)``. ``trainval_dset`` holds the training rows but uses
        the centered, non-augmented dataset class. ``weight_neg_pos`` is
        ``[1 - fraction of target==0, 1 - fraction of target==1]`` computed
        over the whole table, not just the training rows.
    """
    table = pd.read_csv(data_file)
    split_column = 'split' + str(split_index)

    # Class weights: one minus the relative frequency of each class.
    n_rows = len(table)
    weight_neg_pos = [1 - (table.target == label).sum() / n_rows for label in (0, 1)]

    def rows_for(part):
        # Rows assigned to `part`, without any of the split bookkeeping columns.
        subset = table[table[split_column] == part]
        return subset.drop(table.filter(regex='split').columns, axis=1)

    train_rows = rows_for('train')
    train_dset = dataset_singleview(train_rows, transform=transform, norm=norm)
    trainval_dset = dataset_singleview_center(train_rows, transform=None, norm=norm)
    val_dset = dataset_singleview_center(rows_for('val'), transform=None, norm=norm)
    test_dset = dataset_singleview_center(rows_for('test'), transform=None, norm=norm)
    return train_dset, trainval_dset, val_dset, test_dset, weight_neg_pos
33
34
def get_bbox(img):
    """Crop a 2D array to the tight bounding box of its non-zero entries.

    Args:
        img: 2D array; any entry that evaluates as true counts as content.

    Returns:
        The sub-array (a slice of ``img``) spanning the first through the
        last non-zero row and column, both inclusive. An array without any
        non-zero entry yields an empty ``(0, 0)`` crop instead of raising.
    """
    rows = np.flatnonzero(np.any(img, axis=1))
    cols = np.flatnonzero(np.any(img, axis=0))
    if rows.size == 0:
        # Nothing to bound; if no row has content, no column has either.
        return img[:0, :0]
    # rows[-1] / cols[-1] are inclusive indices, slice ends are exclusive:
    # the +1 keeps the last non-zero row and column inside the crop.
    return img[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]
40
41
def pad2square_random(image, size):
    """Place a 2D image at a random position inside a zero-filled square.

    Args:
        image: 2D array no larger than ``size`` along either axis.
        size: Side length of the square output.

    Returns:
        A new ``(size, size)`` array of zeros (NumPy default float dtype)
        with ``image`` copied in at a uniformly drawn offset.

    Raises:
        ValueError: If ``image`` does not fit into a ``size`` x ``size`` square.
    """
    max_row = size - image.shape[0]
    max_col = size - image.shape[1]
    if max_row < 0 or max_col < 0:
        raise ValueError(
            'image of shape {} does not fit into a {}x{} square'.format(
                image.shape, size, size))
    # randint's upper bound is exclusive: +1 makes the last valid offset
    # reachable and lets an exactly fitting image (max == 0) through.
    # The column offset is drawn before the row offset, as before.
    col = np.random.randint(0, max_col + 1)
    row = np.random.randint(0, max_row + 1)
    out = np.zeros((size, size))
    out[row:row + image.shape[0], col:col + image.shape[1]] = image
    return out
51
52
def pad2square_center(image, size):
    """Place a 2D image in the middle of a zero-filled square.

    When the free margin is odd, the extra row/column of padding ends up on
    the bottom/right side.

    Args:
        image: 2D array no larger than ``size`` along either axis.
        size: Side length of the square output.

    Returns:
        A new ``(size, size)`` array of zeros (NumPy default float dtype)
        with ``image`` copied into its center.

    Raises:
        ValueError: If ``image`` does not fit into a ``size`` x ``size`` square.
    """
    height, width = image.shape[0], image.shape[1]
    if height > size or width > size:
        raise ValueError(
            'image of shape {} does not fit into a {}x{} square'.format(
                image.shape, size, size))
    top = (size - height) // 2
    left = (size - width) // 2
    out = np.zeros((size, size))
    out[top:top + height, left:left + width] = image
    return out
57
58
def clip_and_normalize_SUVimage(img, mu=2.13, std=3.39, q=30.00):
    """Clip an image to ``[0, q]`` and standardize it as ``(img - mu) / std``.

    The defaults reproduce the constants that were previously hard-coded
    here; pass other values when working with differently scaled images.

    Args:
        img: Array of intensities.
        mu: Offset subtracted after clipping.
        std: Divisor applied after subtracting ``mu``.
        q: Upper clipping bound; the lower bound is fixed at 0.

    Returns:
        The clipped and standardized array; the input is not modified.
    """
    return (np.clip(img, 0., q) - mu) / std
64
65
def get_image(df, transform, norm):
    """Load one image, crop it, pad it at a random position and make a tensor.

    Args:
        df: One table row providing ``filename`` (may contain glob wildcards;
            the first match is used), ``matrix_size_1`` and ``matrix_size_2``
            (the shape the raw float32 data is reshaped to).
        transform: Optional callable applied to each channel slice of the
            tensor; ``None`` disables it.
        norm: If truthy, the image is passed through
            ``clip_and_normalize_SUVimage``.

    Returns:
        Float tensor with a leading channel axis of length 1; without a
        transform its shape is ``(1, image_size, image_size)``.

    Raises:
        FileNotFoundError: If no file matches ``df.filename`` inside
            ``data_path_img``.
    """
    pattern = os.path.join(data_path_img, df.filename)
    matches = glob.glob(pattern)
    if not matches:
        # Fail with the searched pattern instead of an IndexError further down.
        raise FileNotFoundError('File not found: {}'.format(pattern))
    # glob results already carry the data_path_img prefix; joining it on again
    # would double the directory for relative paths.
    img = np.fromfile(matches[0], dtype='float32')
    img = np.reshape(img, [df.matrix_size_1, df.matrix_size_2])
    # Find bbox
    img = get_bbox(img)
    # Pad randomly
    img = pad2square_random(img, image_size)
    # Norm
    if norm:
        img = clip_and_normalize_SUVimage(img)
    # Make Tensors: add the channel axis in front
    img = torch.FloatTensor(img).unsqueeze(0)
    if transform is not None:
        # Iterating the tensor walks the channel axis, so the transform sees
        # one 2D slice at a time.
        img = torch.stack([transform(x) for x in img])
    return img
84
    
85
def get_image_center(df, transform, norm):
    """Load one image, crop it, pad it centered and make a tensor.

    Args:
        df: One table row providing ``filename`` (may contain glob wildcards;
            the first match is used), ``matrix_size_1`` and ``matrix_size_2``
            (the shape the raw float32 data is reshaped to).
        transform: Optional callable applied to each channel slice of the
            tensor; ``None`` disables it.
        norm: If truthy, the image is passed through
            ``clip_and_normalize_SUVimage``.

    Returns:
        Float tensor with a leading channel axis of length 1; without a
        transform its shape is ``(1, image_size, image_size)``.

    Raises:
        FileNotFoundError: If no file matches ``df.filename`` inside
            ``data_path_img``.
    """
    pattern = os.path.join(data_path_img, df.filename)
    matches = glob.glob(pattern)
    if not matches:
        # Fail with the searched pattern instead of an IndexError further down.
        raise FileNotFoundError('File not found: {}'.format(pattern))
    # glob results already carry the data_path_img prefix; joining it on again
    # would double the directory for relative paths.
    img = np.fromfile(matches[0], dtype='float32')
    img = np.reshape(img, [df.matrix_size_1, df.matrix_size_2])
    # Find bbox
    img = get_bbox(img)
    # Pad centered (deterministic placement)
    img = pad2square_center(img, image_size)
    # Norm
    if norm:
        img = clip_and_normalize_SUVimage(img)
    # Make Tensors: add the channel axis in front
    img = torch.FloatTensor(img).unsqueeze(0)
    if transform is not None:
        # Iterating the tensor walks the channel axis, so the transform sees
        # one 2D slice at a time.
        img = torch.stack([transform(x) for x in img])
    return img
104
105
class dataset_singleview(torch.utils.data.Dataset):
    """Dataset of single-view images that are padded at a random position.

    Each item is ``(image tensor, target)``; images are loaded on access
    through ``get_image``.
    """

    def __init__(self, df, transform=None, norm=False):
        """Keep a private copy of the table together with the loading options."""
        self.transform, self.norm = transform, norm
        self.df = df.copy()

    def errors(self, probs):
        """Score predicted probabilities against the stored targets.

        Args:
            probs: One probability per table row; values >= 0.5 count as a
                positive prediction.

        Returns:
            Tuple ``(auc, ber, fpr, fnr)``: ROC AUC, balanced error rate
            (mean of fpr and fnr), false-positive rate and false-negative rate.
        """
        scored = self.df.copy()
        scored['p'] = probs
        scored['pred'] = (scored.p >= 0.5).astype(int)
        wrong = scored.pred != scored.target
        negatives = scored.target == 0
        positives = scored.target == 1
        # Share of misclassified rows within each true class.
        fpr = (wrong & negatives).sum() / negatives.sum()
        fnr = (wrong & positives).sum() / positives.sum()
        ber = (fpr + fnr) / 2.
        auc = roc_auc_score(scored.target, scored.p)
        return auc, ber, fpr, fnr

    def __getitem__(self, index):
        """Return ``(image, target)`` for the row at position ``index``."""
        row = self.df.iloc[index]
        return get_image(row, self.transform, self.norm), row.target

    def __len__(self):
        """Number of rows in the table."""
        return len(self.df)
130
131
class dataset_singleview_center(torch.utils.data.Dataset):
    """Dataset of single-view images that are padded centered.

    Each item is ``(image tensor, target)``; images are loaded on access
    through ``get_image_center``.
    """

    def __init__(self, df, transform=None, norm=False):
        """Keep a private copy of the table together with the loading options."""
        self.transform, self.norm = transform, norm
        self.df = df.copy()

    def errors(self, probs):
        """Score predicted probabilities against the stored targets.

        Args:
            probs: One probability per table row; values >= 0.5 count as a
                positive prediction.

        Returns:
            Tuple ``(auc, ber, fpr, fnr)``: ROC AUC, balanced error rate
            (mean of fpr and fnr), false-positive rate and false-negative rate.
        """
        scored = self.df.copy()
        scored['p'] = probs
        scored['pred'] = (scored.p >= 0.5).astype(int)
        wrong = scored.pred != scored.target
        negatives = scored.target == 0
        positives = scored.target == 1
        # Share of misclassified rows within each true class.
        fpr = (wrong & negatives).sum() / negatives.sum()
        fnr = (wrong & positives).sum() / positives.sum()
        ber = (fpr + fnr) / 2.
        auc = roc_auc_score(scored.target, scored.p)
        return auc, ber, fpr, fnr

    def __getitem__(self, index):
        """Return ``(image, target)`` for the row at position ``index``."""
        row = self.df.iloc[index]
        return get_image_center(row, self.transform, self.norm), row.target

    def __len__(self):
        """Number of rows in the table."""
        return len(self.df)
156
157
class RandomFlip(object):
    """Flip a 2D image along a randomly chosen axis, or leave it untouched."""

    def __call__(self, image):
        # Three equally likely outcomes: no flip, reverse axis 0 (vertical),
        # reverse axis 1 (horizontal).
        axis = random.choice((None, 0, 1))
        if axis is None:
            return image
        backwards = range(image.shape[axis] - 1, -1, -1)
        if axis == 0:
            return image[backwards, :]
        return image[:, backwards]
169
170
class RandomFlipLeftRight(object):
    """Mirror the image along axis 1 with probability one half."""

    def __call__(self, image):
        # Two equally likely outcomes: keep the image, or reverse axis 1.
        if random.choice((None, 1)) is None:
            return image
        mirrored = range(image.shape[1] - 1, -1, -1)
        return image[:, mirrored]
179
        
180
class RandomRot90(object):
    """Rotate a 2D image by a random multiple of 90 degrees."""

    def __call__(self, image):
        # 0 to 3 quarter turns, each equally likely; zero turns hands back
        # the input object itself.
        quarter_turns = random.randint(0, 3)
        if quarter_turns == 0:
            return image
        return torch.rot90(image, quarter_turns, (0, 1))
189
190
class RandomScale(object):
    """Multiply the image intensities by a random factor.

    The factor is drawn uniformly between 0.85 and 1.15; the image geometry
    is left unchanged.
    """

    def __call__(self, image):
        factor = np.random.uniform(low=0.85, high=1.15, size=1)[0]
        return image * factor
197
198
class RandomNoise(object):
    """Add intensity-dependent Gaussian noise to the image half of the time.

    When noise is applied, negative values are clamped to zero before and
    after adding it, so the result is non-negative. When it is not applied,
    the input is returned as is. The input tensor is never modified.
    """

    def __call__(self, image):
        """Return a noisy copy of ``image`` or ``image`` itself.

        Args:
            image: Tensor of intensities.

        Returns:
            Either the unchanged input object, or a new tensor of the same
            shape holding the clamped, noisy image.
        """
        # Two equally likely outcomes: leave the image alone, or add noise.
        if random.choice((None, 1)) is None:
            return image
        # clamp() builds a new tensor; masked assignment on the argument would
        # silently overwrite the caller's data.
        clean = image.clamp(min=0)
        # Draw order (floor level first, then relative sigma) is unchanged.
        level = np.random.uniform(low=0.001, high=0.02, size=1)
        relative = np.random.uniform(low=0.01, high=0.1, size=1)
        # Per-pixel standard deviation: proportional to intensity plus a floor.
        sigma = relative[0] * clean + level[0]
        noisy = clean + torch.normal(0, sigma)
        return noisy.clamp(min=0)