Diff of /dataset.py [000000] .. [1928b6]

Switch to side-by-side view

--- a
+++ b/dataset.py
@@ -0,0 +1,211 @@
+import numpy as np
+import pandas as pd
+import os
+import glob
+import torch
+import torch.nn as nn
+import random
+import itertools
+from sklearn.metrics import roc_auc_score
+import torch.nn.functional as F
+from scipy import interpolate
+
+data_path_img = '<path-to-binary-image-files>'
+data_file     = '<path-and_name-of-csv-datafile>'
+image_size    = 310 # this is the size (310x310) of the MIP images used in the paper. Adjust to fit your images.
+
+def get_datasets_singleview(transform=None, norm=None, balance=False, split_index=0):
+    split = 'split'+str(split_index)
+    df = pd.read_csv(data_file)
+    # Balance weight
+    weight_neg_pos = [1-(df.target==0).sum()/len(df), 1-(df.target==1).sum()/len(df)]    
+    # Read split
+    df_train = df[df[split]=='train'].drop(df.filter(regex='split').columns,axis=1)
+    train_dset = dataset_singleview(df_train, transform=transform, norm=norm)
+    trainval_dset = dataset_singleview_center(df_train, transform=None, norm=norm)
+    # Val split
+    df_val = df[df[split]=='val'].drop(df.filter(regex='split').columns,axis=1)
+    val_dset = dataset_singleview_center(df_val, transform=None, norm=norm)
+    # Test split
+    df_test = df[df[split]=='test'].drop(df.filter(regex='split').columns,axis=1)
+    test_dset = dataset_singleview_center(df_test, transform=None, norm=norm)
+    return train_dset,trainval_dset,val_dset,test_dset,weight_neg_pos
+
+def get_bbox(img):
+    rows = np.any(img, axis=1)
+    cols = np.any(img, axis=0)
+    rmin, rmax = np.where(rows)[0][[0, -1]]
+    cmin, cmax = np.where(cols)[0][[0, -1]]
+    return img[rmin:rmax, cmin:cmax]
+
+def pad2square_random(image, size):
+    out = np.zeros((size,size))
+    # Sample offset
+    maxr = size - image.shape[0]
+    maxc = size - image.shape[1]
+    offsetc = np.random.randint(0, maxc)
+    offsetr = np.random.randint(0, maxr)
+    # Place image
+    out[offsetr:offsetr+image.shape[0], offsetc:offsetc+image.shape[1]] = image
+    return out
+
+def pad2square_center(image, size):
+    # Place image
+    out = np.zeros((size,size))
+    out[int((size-image.shape[0])/2):int((size-image.shape[0])/2)+image.shape[0],int((size-image.shape[1])/2):int((size-image.shape[1])/2)+image.shape[1]] = image
+    return out
+
+def clip_and_normalize_SUVimage(img):
+    mu  = 2.13
+    std = 3.39
+    q   = 30.00
+    img = np.clip(img,0.,q)
+    return (img-mu)/std
+
+def get_image(df, transform, norm):
+    name = glob.glob(os.path.join(data_path_img,df.filename))
+    if not name:
+        print('File not found:',name)
+    img  = np.fromfile(os.path.join(data_path_img, name[0]), dtype='float32')
+    img  = np.reshape( img,[df.matrix_size_1, df.matrix_size_2])
+    # Find bbox
+    img = get_bbox(img)
+    # Pad randomly
+    img = pad2square_random(img, image_size)
+    # Norm
+    if norm:
+        img = clip_and_normalize_SUVimage(img)
+    # Make Tensors
+    img = torch.FloatTensor(img).unsqueeze(0)
+    if transform is not None:
+        img = [transform(x) for x in img]
+        img = torch.stack(img)
+    return img
+    
+def get_image_center(df, transform, norm):
+    name = glob.glob(os.path.join(data_path_img,df.filename))
+    if not name:
+        print('File not found:',name)
+    img  = np.fromfile(os.path.join(data_path_img, name[0]), dtype='float32')
+    img  = np.reshape( img,[df.matrix_size_1, df.matrix_size_2])
+    # Find bbox
+    img = get_bbox(img)
+    # Pad randomly
+    img = pad2square_center(img, image_size)
+    # Norm
+    if norm:
+        img = clip_and_normalize_SUVimage(img)
+    # Make Tensors
+    img = torch.FloatTensor(img).unsqueeze(0)
+    if transform is not None:
+        img = [transform(x) for x in img]
+        img = torch.stack(img)
+    return img
+
+class dataset_singleview(torch.utils.data.Dataset):
+    def __init__(self, df, transform=None, norm=False):
+        self.df = df.copy()
+        self.transform = transform
+        self.norm = norm
+    
+    def errors(self, probs):
+        df = self.df.copy()
+        df['p'] = probs
+        df['pred'] = (df.p >= 0.5).astype(int)
+        fpr = ((df.pred!=df.target) & (df.target==0)).sum() / (df.target==0).sum()
+        fnr = ((df.pred!=df.target) & (df.target==1)).sum() / (df.target==1).sum()
+        ber = (fpr + fnr) / 2.
+        ## Calculate auc
+        auc = roc_auc_score(df.target, df.p)
+        return auc, ber, fpr, fnr
+    
+    def __getitem__(self, index): 
+        df = self.df.iloc[index]
+        # Read image
+        img = get_image(df, self.transform, self.norm)
+        return img, df.target
+    
+    def __len__(self):
+        return len(self.df)
+
+class dataset_singleview_center(torch.utils.data.Dataset):
+    def __init__(self, df, transform=None, norm=False):
+        self.df = df.copy()
+        self.transform = transform
+        self.norm = norm
+    
+    def errors(self, probs):
+        df = self.df.copy()
+        df['p'] = probs
+        df['pred'] = (df.p >= 0.5).astype(int)
+        fpr = ((df.pred!=df.target) & (df.target==0)).sum() / (df.target==0).sum()
+        fnr = ((df.pred!=df.target) & (df.target==1)).sum() / (df.target==1).sum()
+        ber = (fpr + fnr) / 2.
+        ## Calculate auc
+        auc = roc_auc_score(df.target, df.p)
+        return auc, ber, fpr, fnr
+    
+    def __getitem__(self, index): 
+        df = self.df.iloc[index]
+        # Read image
+        img = get_image_center(df, self.transform, self.norm)
+        return img, df.target
+    
+    def __len__(self):
+        return len(self.df)
+
+class RandomFlip(object):
+    """Randomly flip the 2D image.
+    """
+    def __call__(self, image):
+        # Random flip: none, 0=vertical, 1=horizontal
+        flip = random.choice((None,0,1))
+        if flip is not None:
+            if flip==0:
+                image   = image[range(image.shape[flip]-1,-1,-1),:]
+            elif flip==1:
+                image   = image[:,range(image.shape[flip]-1,-1,-1)]       
+        return image
+
+class RandomFlipLeftRight(object):
+    """Randomly flip all channels of the 2D image.
+    """
+    def __call__(self, image):
+        # Random flip: none, 0=vertical, 1=horizontal
+        flip = random.choice((None,1))
+        if flip is not None:
+            image = image[:,range(image.shape[1]-1,-1,-1)]
+        return image
+        
+class RandomRot90(object):
+    """Randomly rotate the 2D image by n*90 degrees.
+    """
+    def __call__(self, image):
+        # Random  90 rotation
+        rot = random.randint(0,3)
+        if rot != 0:
+            image = torch.rot90(image, rot, (0,1))
+        return image
+
+class RandomScale(object):
+    """Randomly scale the 2D image.
+    """
+    def __call__(self, image):
+        scale = np.random.uniform(low=0.85, high=1.15, size=1)
+        image = image*scale[0]
+        return image
+
+class RandomNoise(object):
+    """Randomly gauss noise the 2D image.
+    """
+    def __call__(self, image):
+        noise = random.choice((None,1))
+        if noise is not None:
+            image[image<0] = 0
+            level = np.random.uniform(low=0.001, high=0.02, size=1)
+            sigma = np.random.uniform(low=0.01, high=0.1, size=1)
+            sigma = sigma[0]*image+level[0]
+            gauss = torch.normal(0,sigma)
+            image = image + gauss
+            image[image<0] = 0
+        return image