Diff of /3DNet/datasets/RSNA19.py [000000] .. [9f60b7]

Switch to side-by-side view

--- a
+++ b/3DNet/datasets/RSNA19.py
@@ -0,0 +1,205 @@
+'''
+Dataset for training
+Written by Whalechen
+'''
+
+import math
+import os
+import random
+
+import numpy as np
+from torch.utils.data import Dataset
+import nibabel
+from scipy import ndimage
+
class RSNA19Dataset(Dataset):
    """3D volume dataset (RSNA 2019 style) for segmentation training/testing.

    Each line of ``img_list`` names one sample, relative to ``root_dir``:
    ``"<image.nii> <label.nii>"`` in train phase, ``"<image.nii> ..."`` in
    test phase.  Volumes are cropped, resized to a fixed
    (input_D, input_H, input_W) grid and intensity-normalized; samples are
    returned as float32 arrays of shape (1, D, H, W).
    """

    def __init__(self, root_dir, img_list, sets):
        """
        Args:
            root_dir: directory that image/label paths in ``img_list`` are
                relative to.
            img_list: text file, one sample description per line (see class
                docstring).
            sets: settings object providing ``input_D``, ``input_H``,
                ``input_W`` and ``phase`` ("train" or "test").
        """
        with open(img_list, 'r') as f:
            self.img_list = [line.strip() for line in f]

        print("Processing {} datas".format(len(self.img_list)))

        self.root_dir = root_dir
        self.input_D = sets.input_D
        self.input_H = sets.input_H
        self.input_W = sets.input_W
        self.phase = sets.phase

    def __nii2tensorarray__(self, data):
        """Add a leading channel axis and cast: (D, H, W) -> float32 (1, D, H, W)."""
        depth, height, width = data.shape
        new_data = np.reshape(data, [1, depth, height, width])
        return new_data.astype("float32")

    def __len__(self):
        """Number of samples listed in the list file."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load, preprocess and return sample ``idx``.

        Returns:
            train phase: ``(img_array, mask_array)``, both float32 (1, D, H, W).
            test phase: ``img_array`` only.

        Raises:
            FileNotFoundError: if a listed image/label file is missing.
            ValueError: if ``self.phase`` is neither "train" nor "test".
        """
        if self.phase == "train":
            # One line = "<image path> <label path>" relative to root_dir.
            ith_info = self.img_list[idx].split(" ")
            img_name = os.path.join(self.root_dir, ith_info[0])
            label_name = os.path.join(self.root_dir, ith_info[1])

            # Explicit raises instead of assert: asserts vanish under `python -O`.
            if not os.path.isfile(img_name):
                raise FileNotFoundError(img_name)
            if not os.path.isfile(label_name):
                raise FileNotFoundError(label_name)

            # nibabel.load raises on failure, so no None checks are needed.
            img = nibabel.load(img_name)  # data assumed already transposed WHD -> DHW upstream
            mask = nibabel.load(label_name)

            # Crop / resize / normalize image and label consistently.
            img_array, mask_array = self.__training_data_process__(img, mask)

            # To (1, D, H, W) float32 tensors.
            img_array = self.__nii2tensorarray__(img_array)
            mask_array = self.__nii2tensorarray__(mask_array)

            assert img_array.shape == mask_array.shape, \
                "img shape:{} is not equal to mask shape:{}".format(img_array.shape, mask_array.shape)
            return img_array, mask_array

        elif self.phase == "test":
            # Test lines carry at least the image path; labels are not read.
            ith_info = self.img_list[idx].split(" ")
            img_name = os.path.join(self.root_dir, ith_info[0])
            print(img_name)
            if not os.path.isfile(img_name):
                raise FileNotFoundError(img_name)
            img = nibabel.load(img_name)

            img_array = self.__testing_data_process__(img)
            return self.__nii2tensorarray__(img_array)

        # Previously this fell through and returned None; fail loudly instead.
        raise ValueError("Unknown phase: {!r} (expected 'train' or 'test')".format(self.phase))

    def __drop_invalid_range__(self, volume, label=None):
        """Crop away the constant background border of ``volume``.

        The corner voxel ``volume[0, 0, 0]`` is taken as the background
        value; the tight bounding box of all non-background voxels is kept.
        ``label`` (if given) is cropped with the same box.
        """
        zero_value = volume[0, 0, 0]
        non_zeros_idx = np.where(volume != zero_value)

        [max_z, max_h, max_w] = np.max(np.array(non_zeros_idx), axis=1)
        [min_z, min_h, min_w] = np.min(np.array(non_zeros_idx), axis=1)

        # +1 on the upper bounds: Python slices exclude the stop index, so the
        # original min:max slicing dropped the outermost foreground voxel on
        # every axis (off-by-one).
        z_sl = slice(min_z, max_z + 1)
        h_sl = slice(min_h, max_h + 1)
        w_sl = slice(min_w, max_w + 1)

        if label is not None:
            return volume[z_sl, h_sl, w_sl], label[z_sl, h_sl, w_sl]
        return volume[z_sl, h_sl, w_sl]

    def __random_center_crop__(self, data, label):
        """Randomly crop ``data``/``label`` while keeping every label voxel.

        The crop window is sampled between the label's bounding box (padded by
        half its extent) and the full volume, so all foreground is preserved.
        NOTE(review): assumes ``label`` has at least one voxel > 0 — np.max on
        an empty index set would fail; verify against the data pipeline.
        """
        # Use the module-level `random` import; the original re-imported
        # `random.random` locally, shadowing the module name.
        target_indexs = np.where(label > 0)
        [img_d, img_h, img_w] = data.shape
        [max_D, max_H, max_W] = np.max(np.array(target_indexs), axis=1)
        [min_D, min_H, min_W] = np.min(np.array(target_indexs), axis=1)
        [target_depth, target_height, target_width] = \
            np.array([max_D, max_H, max_W]) - np.array([min_D, min_H, min_W])

        # Lower corner: somewhere between 0 and (bbox min - half extent).
        Z_min = int((min_D - target_depth * 1.0 / 2) * random.random())
        Y_min = int((min_H - target_height * 1.0 / 2) * random.random())
        X_min = int((min_W - target_width * 1.0 / 2) * random.random())

        # Upper corner: somewhere between (bbox max + half extent) and the edge.
        Z_max = int(img_d - ((img_d - (max_D + target_depth * 1.0 / 2)) * random.random()))
        Y_max = int(img_h - ((img_h - (max_H + target_height * 1.0 / 2)) * random.random()))
        X_max = int(img_w - ((img_w - (max_W + target_width * 1.0 / 2)) * random.random()))

        # Clamp to the volume bounds.
        Z_min, Y_min, X_min = max(0, Z_min), max(0, Y_min), max(0, X_min)
        Z_max, Y_max, X_max = min(img_d, Z_max), min(img_h, Y_max), min(img_w, X_max)

        return (data[Z_min:Z_max, Y_min:Y_max, X_min:X_max],
                label[Z_min:Z_max, Y_min:Y_max, X_min:X_max])

    def __itensity_normalize_one_volume__(self, volume):
        """Normalize intensity by the mean/std of the positive (foreground) region.

        Background voxels (== 0) are replaced with unit Gaussian noise so the
        network cannot latch onto a constant background value.

        Args:
            volume: input nd volume.
        Returns:
            the normalized nd volume (float).
        """
        pixels = volume[volume > 0]
        if pixels.size == 0:
            # Degenerate volume with no positive voxels: everything would have
            # been replaced by noise below anyway; avoid NaN mean/std.
            return np.random.normal(0, 1, size=volume.shape)
        mean = pixels.mean()
        std = pixels.std()
        if std == 0:
            # Constant foreground: avoid division by zero, just center it.
            std = 1.0
        out = (volume - mean) / std
        out_random = np.random.normal(0, 1, size=volume.shape)
        out[volume == 0] = out_random[volume == 0]
        return out

    def __resize_data__(self, data):
        """Resize ``data`` to (input_D, input_H, input_W) with nearest-neighbour
        interpolation (order=0, label-safe)."""
        [depth, height, width] = data.shape
        scale = [self.input_D * 1.0 / depth,
                 self.input_H * 1.0 / height,
                 self.input_W * 1.0 / width]
        # ndimage.interpolation.zoom is deprecated; ndimage.zoom is the
        # supported spelling.
        return ndimage.zoom(data, scale, order=0)

    def __crop_data__(self, data, label):
        """Crop augmentation hook — currently a random center crop."""
        data, label = self.__random_center_crop__(data, label)
        return data, label

    def __training_data_process__(self, data, label):
        """Full train-time pipeline: load arrays, drop background border,
        random crop, resize to network input size, normalize intensities."""
        # np.asanyarray(x.dataobj) is the documented replacement for the
        # removed nibabel get_data(); it preserves the on-disk dtype.
        data = np.asanyarray(data.dataobj)
        label = np.asanyarray(label.dataobj)

        # drop out the invalid (constant background) range
        data, label = self.__drop_invalid_range__(data, label)

        # crop data (augmentation)
        data, label = self.__crop_data__(data, label)

        # resize both to the fixed network input size
        data = self.__resize_data__(data)
        label = self.__resize_data__(label)

        # normalize image intensities only (labels stay as class values)
        data = self.__itensity_normalize_one_volume__(data)

        return data, label

    def __testing_data_process__(self, data):
        """Test-time pipeline: load array, resize, normalize (no crop)."""
        data = np.asanyarray(data.dataobj)

        # resize data
        data = self.__resize_data__(data)

        # normalization
        data = self.__itensity_normalize_one_volume__(data)

        return data