Switch to unified view

a b/dataprocess/dataprocess.py
1
import random
2
from .segdataloader import *
3
from .utils import *
4
import csv
5
import glob 
6
import cv2
7
8
fold = 1
9
10
def readCSV(filename):
11
    lines = []
12
    with open(filename, "r") as f:
13
        csvreader = csv.reader(f)
14
        for line in csvreader:
15
            lines.append(line[0])
16
    return lines
17
18
19
def get_dataloader(config, mode='train', batchsize=64, width=64, height=64):
20
21
    train_datas = []
22
    train_masks = []
23
    for index in config.training_fold_index:
24
        tempdata = readCSV(os.path.join(config.csvPath, 'data_fold' + str(index) + '.csv'))
25
        tempmask = readCSV(os.path.join(config.csvPath, 'mask_fold' + str(index) + '.csv'))
26
27
        train_datas += tempdata
28
        train_masks += tempmask
29
30
    test_datas = readCSV(os.path.join(config.csvPath, 'data_fold' + str(config.test_fold_index[0]) + '.csv'))
31
    test_masks = readCSV(os.path.join(config.csvPath, 'mask_fold' + str(config.test_fold_index[0]) + '.csv'))
32
33
    
34
    if mode=='train':
35
        # remove features labels
36
        temp_train_datas = []
37
        for one in train_datas:
38
            one_temp = one.split('/')[-1]
39
            one_list = one_temp.split('_')
40
            temp_train_datas.append(one_list[0] + '_' + one_list[1] + '_' + one_list[2])
41
        temp_test_datas = []
42
        for one in test_datas:
43
            one_temp = one.split('/')[-1]
44
            one_list = one_temp.split('_')
45
            temp_test_datas.append(one_list[0] + '_' + one_list[1] + '_' + one_list[2])
46
47
        mid_files = os.listdir(config.maskPath2)
48
49
50
        temp2_train_inter = []
51
        temp2_train_union = []
52
        temp2_train_data = []
53
        temp2_train_lung = []
54
        temp2_train_media = []
55
        temp2_train_mask = []
56
57
        for one_train_data in temp_train_datas:
58
            imagename = one_train_data + '.png'
59
60
            if imagename in mid_files:
61
                innertemp0 = config.midPath + one_train_data + '.npy'
62
                innertemp1 = config.lungPath + one_train_data + '_lung.npy'
63
                innertemp2 = config.mediaPath + one_train_data + '_mediastinal.npy'
64
                innertemp3 = config.maskPath2 + one_train_data + '_red.png'
65
                innertemp4 = config.maskPath2 + one_train_data + '_blue.png'
66
                innertemp5 = config.maskPath1 + 'mid_' + one_train_data + '_mask.png'
67
                temp2_train_data.append(innertemp0)
68
                temp2_train_lung.append(innertemp1)
69
                temp2_train_media.append(innertemp2)
70
                temp2_train_union.append(innertemp3) 
71
                temp2_train_inter.append(innertemp4)
72
                temp2_train_mask.append(innertemp5)
73
74
75
        temp2_test_data = []
76
        temp2_test_lung = []
77
        temp2_test_media = []
78
        temp2_test_inter = []
79
        temp2_test_union = []
80
        temp2_test_mask = []
81
        for one_test_data in temp_test_datas:
82
            imagename = one_test_data + '.png'       
83
            if imagename in mid_files:
84
                innertemp0 = config.midPath + one_test_data + '.npy' 
85
                innertemp1 = config.lungPath + one_test_data + '_lung.npy'
86
                innertemp2 = config.mediaPath + one_test_data + '_mediastinal.npy'
87
                innertemp3 = config.maskPath2 + one_test_data + '_red.png'
88
                innertemp4 = config.maskPath2 + one_test_data + '_blue.png'
89
                innertemp5 = config.maskPath1 + 'mid_' + one_test_data + '_mask.png'
90
                temp2_test_data.append(innertemp0)
91
                temp2_test_lung.append(innertemp1)
92
                temp2_test_media.append(innertemp2)
93
                temp2_test_union.append(innertemp3) 
94
                temp2_test_inter.append(innertemp4)
95
                temp2_test_mask.append(innertemp5)
96
97
98
        print('***********')
99
        print('the length of train data: ', len(temp2_train_data))
100
        print('the length of test data: ', len(temp2_test_data))
101
        print('-----------')
102
        dataloader = loader(Dataset(temp2_train_data, temp2_train_lung, temp2_train_media, temp2_train_inter, temp2_train_union, temp2_train_mask,  width=width, height=height), batchsize)
103
        dataloader_val = loader(Dataset(temp2_test_data, temp2_test_lung, temp2_test_media, temp2_test_inter, temp2_test_union, temp2_test_mask, width=width, height=height), batchsize)
104
        return dataloader, dataloader_val
105
106
    if mode=='row':
107
        # remove features labels
108
        temp_train_datas = []
109
        for one in train_datas:
110
            one_temp = one.split('/')[-1]
111
            one_list = one_temp.split('_')
112
            temp_train_datas.append('mid_' + one_list[0] + '_' + one_list[1] + '_' + one_list[2])
113
        temp_test_datas = []
114
        for one in test_datas:
115
            one_temp = one.split('/')[-1]
116
            one_list = one_temp.split('_')
117
            temp_test_datas.append('mid_' + one_list[0] + '_' + one_list[1] + '_' + one_list[2])
118
        temp2_train_datas = []
119
        temp2_train_masks = []
120
        temp2_test_datas = []
121
        temp2_test_masks = []
122
        row_files = os.listdir(config.rowPath)
123
        for one_train_data in temp_train_datas:
124
            imagename = one_train_data + '.png'
125
            if imagename in row_files:
126
                innertemp0 = config.rowPath + one_train_data + '.png'
127
                innertemp1 = config.rowPath + one_train_data + '_mask.png'
128
                temp2_train_datas.append(innertemp0)
129
                temp2_train_masks.append(innertemp1)
130
        for one_test_data in temp_test_datas:
131
            imagename = one_test_data + '.png'
132
            if imagename in row_files:
133
                innertemp0 = config.rowPath + one_test_data + '.png'
134
                innertemp1 = config.rowPath + one_test_data + '_mask.png'
135
                temp2_test_datas.append(innertemp0)
136
                temp2_test_masks.append(innertemp1)
137
138
        dataloader = loader(RowDataset(temp2_train_datas, temp2_train_masks, width=width, height=height), batchsize)
139
        dataloader_val = loader(RowDataset(temp2_test_datas, temp2_test_masks, width=width, height=height), batchsize)
140
141
        return dataloader, dataloader_val
142