Diff of /data.py [000000] .. [48d89d]

Switch to unified view

a b/data.py
1
"""
2
This code is to
3
1. Create train & test input to Network as numpy arrays
4
2. Load the train & test numpy arrays
5
"""
6
7
import numpy as np
8
from utils import *
9
10
# data type to save as np array
11
npdtype = np.float32
12
13
14
def create_train_data(current_fold, plane):
15
    """
16
    Crop each slice by its ground truth bounding box,
17
    then pad zeros to form uniform dimension,
18
    rescale pixel intensities to [0,1]
19
    """
20
    # get the list of image and label number of current_fold
21
    imlb_list = open(training_set_filename(current_fold), 'r').read().splitlines()
22
    current_fold = current_fold
23
    training_image_set = np.zeros((len(imlb_list)), dtype = np.int)
24
25
    for i in range(len(imlb_list)):
26
        s = imlb_list[i].split(' ')
27
        training_image_set[i] = int(s[0])
28
29
    slice_list = open(list_training[plane], 'r').read().splitlines()
30
    slices = len( slice_list)
31
    image_ID = np.zeros(( slices), dtype = np.int)
32
    slice_ID = np.zeros(( slices), dtype = np.int)
33
    image_filename = ['' for l in range( slices)]
34
    label_filename = ['' for l in range( slices)]
35
    pixels = np.zeros((slices), dtype = np.int)
36
37
    for l in range(slices):
38
        s =  slice_list[l].split(' ')
39
        image_ID[l] = s[0]
40
        slice_ID[l] = s[1]
41
        image_filename[l] = s[2]
42
        label_filename[l] = s[3]
43
        pixels[l] = int(s[organ_ID * 5])
44
45
    create_slice_list = []
46
    create_label_list = []
47
48
    for l in range(slices):
49
        if image_ID[l] in training_image_set and pixels[l] >= 100:
50
            create_slice_list.append(image_filename[l])
51
            create_label_list.append(label_filename[l])
52
    if len(create_slice_list)!= len(create_label_list):
53
        raise ValueError('slice number does not equal label number!')
54
55
    total = len(create_slice_list)
56
57
    img_rows = XMAX
58
    img_cols = YMAX
59
60
    imgs = np.ndarray((total, img_rows, img_cols), dtype = npdtype)
61
    imgs_mask = np.ndarray((total, img_rows, img_cols), dtype = npdtype)
62
63
    print('-'*30)
64
    print('  Creating training data...')
65
    print('-'*30)
66
67
    for i in range(len(create_slice_list)):
68
        cur_im = np.load(create_slice_list[i])
69
        cur_mask = np.load(create_label_list[i])
70
71
        cur_im = (cur_im - low_range) / float(high_range - low_range)
72
        arr = np.nonzero(cur_mask)
73
74
        width = cur_mask.shape[0]
75
        height = cur_mask.shape[1]
76
77
        minA = min(arr[0])
78
        maxA = max(arr[0])
79
        minB = min(arr[1])
80
        maxB = max(arr[1])
81
82
        # with margin
83
        cropped_im = cur_im[max(minA - margin, 0): min(maxA + margin + 1, width), \
84
                                    max(minB - margin, 0): min(maxB + margin + 1, height)]
85
        cropped_mask = cur_mask[max(minA - margin, 0): min(maxA + margin + 1, width), \
86
                                    max(minB - margin, 0): min(maxB + margin + 1, height)]
87
88
        imgs[i] = pad_2d(cropped_im, plane, 0, XMAX, YMAX, ZMAX)
89
        imgs_mask[i] = pad_2d(cropped_mask, plane, 0, XMAX, YMAX, ZMAX)
90
91
        if i % 100 == 0:
92
            print('Done: {0}/{1} slices'.format(i, total))
93
94
    np.save('imgs_train_%s_%s.npy'%(current_fold, plane), imgs)
95
    np.save('masks_train_%s_%s.npy'%(current_fold, plane), imgs_mask)
96
    print('Training data created for fold %s, plane %s'%(current_fold, plane))
97
98
99
def load_train_data(current_fold, plane):
100
    imgs_train = np.load('imgs_train_%s_%s.npy'%(current_fold, plane))
101
    mask_train = np.load('masks_train_%s_%s.npy'%(current_fold, plane))
102
    return imgs_train, mask_train
103
104
105
def load_test_data(current_fold, plane):
106
    imgs_test = np.load('imgs_test_%s_%s.npy'%(current_fold, plane))
107
    mask_test = np.load('masks_test_$s_%s.npy'%(current_fold, plane))
108
    return imgs_test, mask_test
109
110
111
if __name__ == '__main__':
112
113
    data_path = sys.argv[1]
114
    current_fold = int(sys.argv[2])
115
    plane = sys.argv[3]
116
117
    # dim of each case (after padding zeors to max gt bounding box)
118
    ZMAX = int(sys.argv[4])
119
    YMAX = int(sys.argv[5])
120
    XMAX = int(sys.argv[6])
121
122
    margin = int(sys.argv[7])
123
    organ_ID = int(sys.argv[8])
124
    low_range = int(sys.argv[9])
125
    high_range = int(sys.argv[10])
126
127
    create_train_data(current_fold, plane)