Diff of /rs_dataset.py [000000] .. [5ba3a6]

Switch to unified view

a b/rs_dataset.py
1
# -*- coding: utf-8 -*-
"""
@File    : rs_dataset.py
@Time    : 2019/6/22 10:57
@Author  : Parker
@Email   : now_cherish@163.com
@Software: PyCharm
@Des     : PyTorch dataset classes for the RSNA DICOM data
"""
10
11
import csv
12
import torch
13
from torch.utils.data import Dataset
14
import torchvision.transforms as transforms
15
import pydicom
16
import os.path as osp
17
import os
18
from PIL import Image
19
import numpy as np
20
import random
21
import cv2
22
from tqdm import tqdm
23
import matplotlib.pyplot as plt
24
import time
25
from skimage.morphology import remove_small_holes, remove_small_objects
26
from skimage.measure import label, regionprops
27
from skimage.filters import threshold_otsu
28
29
def data_understanding():
    """Print how many images carry each combination of labels.

    The label vectors come from ``prepare_label()``. Each vector is read as a
    binary number (first label = most significant bit), the image ids are
    grouped by that number, and one line is printed per combination that
    occurs: the bit pattern zero-padded to six digits, then the image count.
    """
    labels = prepare_label()

    # combination code -> ids of the images that carry exactly that combination
    ids_by_combo = {}
    for image_id, flags in tqdm(labels.items()):
        combo = int("".join(map(str, flags)), 2)
        ids_by_combo.setdefault(combo, []).append(image_id)

    for combo, image_ids in ids_by_combo.items():
        print(format(combo, '06b'), len(image_ids))
46
47
48
def prepare_label(csv_path='/media/tiger/zzr/rsna/stage_1_train.csv'):
    """Read the training CSV into one label vector per image.

    The first row is skipped as a header. In every other row, column 0 is
    split on ``'_'``: the first two pieces joined by ``'_'`` form the image
    id and the third piece is the label name; column 1 is parsed as an int.

    :param csv_path: path of the CSV file to read; defaults to the location
                     that used to be hard-coded here
    :return: dict mapping image id to a list of six ints ordered as
             epidural, intraparenchymal, intraventricular, subarachnoid,
             subdural, any (entries without a row stay 0)
    :raises KeyError: if a row names a label outside the six listed above
    """
    label_names = ["epidural", "intraparenchymal", "intraventricular",
                   "subarachnoid", "subdural", "any"]
    label_ranks = {name: rank for rank, name in enumerate(label_names)}
    all_true_labels = {}

    # newline='' is the mode the csv module documents for files given to it.
    with open(csv_path, 'r', newline='') as fp:
        csv_reader = csv.reader(fp, delimiter=',')
        next(csv_reader, None)  # skip the header row
        print('processing data ...')
        for row in tqdm(csv_reader):
            parts = row[0].split('_')
            image_id = "_".join(parts[:2])
            label_id = label_ranks[parts[2]]
            if image_id not in all_true_labels:
                all_true_labels[image_id] = [0] * len(label_names)
            all_true_labels[image_id][label_id] = int(row[1])

    return all_true_labels
68
69
70
class RSDataset(Dataset):
    """Dataset of windowed, cropped DICOM images with six-element label vectors.

    Every file found under ``<rootpth>/<mode>`` is one sample. A sample is
    read with pydicom, rescaled, clipped to its display window, cropped to
    its largest bright region, resized to ``des_size`` and normalised.
    Labels are read from ``<rootpth>/stage_1_train.csv``.
    """

    # Order of the entries in every label vector.
    _LABEL_NAMES = ("epidural", "intraparenchymal", "intraventricular",
                    "subarachnoid", "subdural", "any")
    _MODES = ('train', 'val', 'test')
    # File name left out of the sample list (the reason is not recorded here).
    _SKIPPED_FILE = 'ID_6431af929.dcm'
    # Mean / std handed to transforms.Normalize.
    _PIXEL_MEAN = 32.98408291578699
    _PIXEL_STD = 33.70147134726827

    def __init__(self, rootpth='/media/tiger/zzr/rsna', des_size=(512, 512), mode='train'):
        """
        :param rootpth: root directory; must hold ``stage_1_train.csv`` and a
                        sub-directory named after ``mode`` with the image files
        :param des_size: target size, passed to ``cv2.resize`` as ``dsize``
        :param mode: 'train', 'val' or 'test'
        :raises ValueError: if ``mode`` is not one of the three names above
        """
        # An explicit raise, unlike an assert, still runs under ``python -O``.
        if mode not in self._MODES:
            raise ValueError("mode must be 'train', 'val' or 'test', got %r" % (mode,))

        self.root_path = rootpth
        self.des_size = des_size
        self.mode = mode
        self.name = None  # path of the most recently requested sample

        # Map label name -> position in the label vector, then load the labels.
        self.label_ranks = {name: rank for rank, name in enumerate(self._LABEL_NAMES)}
        self.labels = self.prepare_label()

        # Collect the path of every sample file.
        self.file_names = []
        for root, _dirs, names in os.walk(osp.join(rootpth, mode)):
            for name in names:
                if name == self._SKIPPED_FILE:
                    continue
                self.file_names.append(osp.join(root, name))

        # Path separator used by the collected paths; an empty directory
        # falls back to '/' instead of failing on file_names[0].
        first_path = self.file_names[0] if self.file_names else ''
        self.split_char = '\\' if '\\' in first_path else '/'

        # ndarray -> tensor, then normalisation. Normalize documents mean/std
        # as per-channel sequences, so the single channel gets 1-tuples.
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((self._PIXEL_MEAN,), (self._PIXEL_STD,)),
        ])

    @staticmethod
    def _first_int(value):
        """Return ``value``, or its first element if it is indexable, as an int."""
        try:
            value = value[0]
        except TypeError:
            pass  # plain scalar: use it unchanged
        return int(value)

    def data_loader(self, fname):
        """
        Load one DICOM file as a float32 image.

        The pixel array is rescaled with RescaleSlope / RescaleIntercept,
        clipped to [center - width / 2, center + width / 2] taken from
        WindowCenter / WindowWidth, and cropped by ``preprocess``.

        :param fname: path of the DICOM file
        :return: 2-D float32 array (cropped)
        """
        ds = pydicom.dcmread(fname)
        # Either attribute may be a single value or a list of values; each is
        # resolved independently so a mixed file still loads.
        window_center = self._first_int(ds.WindowCenter)
        window_width = self._first_int(ds.WindowWidth)
        lower = window_center - window_width / 2
        upper = window_center + window_width / 2

        rescaled = ds.pixel_array * ds.RescaleSlope + ds.RescaleIntercept
        data = np.clip(rescaled, lower, upper).astype(np.float32)
        return self.preprocess(data)

    def preprocess(self, data):
        """
        Crop the image to the bounding box of its largest bright region.

        Pixels above the Otsu threshold form the foreground; small objects
        are removed, the remaining connected regions are labelled and the
        one with the largest area defines the crop. Without any region the
        image is returned uncropped.

        :param data: 2-D image array
        :return: cropped view of ``data``
        """
        try:
            thres = threshold_otsu(data)
        except ValueError:
            # Otsu can reject degenerate input (e.g. a single-valued image in
            # some scikit-image versions). With the minimum as threshold no
            # pixel counts as foreground, so the full image is kept below.
            thres = np.min(data)

        foreground = remove_small_objects(data > thres)
        regions = regionprops(label(foreground))

        if regions:
            # max() keeps the first of equally large regions.
            min_row, min_col, max_row, max_col = max(regions, key=lambda r: r.area).bbox
        else:
            min_row, min_col = 0, 0
            max_row, max_col = np.shape(data)[0], np.shape(data)[1]

        # regionprops bounding boxes are half-open [min, max), so the max
        # values are already valid slice ends; adding 1 would take one row
        # and one column too many.
        return data[min_row:max_row, min_col:max_col]

    def prepare_label(self):
        """
        Read ``<root_path>/stage_1_train.csv`` into one label vector per image.

        The first row is skipped as a header. Column 0 of every other row is
        split on ``'_'``: the first two pieces form the image id, the third
        is the label name; column 1 is parsed as a float.

        :return: dict mapping image id to a six-element list ordered like
                 ``self.label_ranks``
        """
        all_true_labels = {}
        csv_path = osp.join(self.root_path, 'stage_1_train.csv')
        # newline='' is the mode the csv module documents for files given to it.
        with open(csv_path, 'r', newline='') as fp:
            csv_reader = csv.reader(fp, delimiter=',')
            next(csv_reader, None)  # skip the header row
            for row in tqdm(csv_reader):
                parts = row[0].split('_')
                image_id = "_".join(parts[:2])
                label_id = self.label_ranks[parts[2]]
                if image_id not in all_true_labels:
                    all_true_labels[image_id] = [0] * 6
                all_true_labels[image_id][label_id] = float(row[1])

        return all_true_labels

    def __getitem__(self, idx):
        """
        :param idx: index into ``self.file_names``
        :return: (normalised image tensor, label tensor)
        """
        self.name = self.file_names[idx]
        # File name without directory and without extension is the label key.
        image_id = self.name.split(self.split_char)[-1].split('.')[0]
        category = self.labels[image_id]
        img = cv2.resize(self.data_loader(self.name), dsize=self.des_size,
                         interpolation=cv2.INTER_LINEAR)
        return self.to_tensor(img), torch.tensor(category)

    def __len__(self):
        """Number of sample files."""
        return len(self.file_names)

    def calculateMeanStd(self, idx):
        """
        Mean and standard deviation of one loaded (cropped, not resized) image.

        :param idx: index into ``self.file_names``
        :return: (mean, std) of the image's pixel values
        """
        self.name = self.file_names[idx]
        img = self.data_loader(self.name)

        return np.mean(img), np.std(img)
192
193
194
class RSDataset_test(RSDataset):
    """Variant of ``RSDataset`` whose samples are ``(tensor, image_id)`` pairs.

    No label lookup happens per sample; the second element is the file name
    without directory and extension.
    """

    def __init__(self, rootpth='/media/tiger/zzr/rsna', des_size=(512, 512), mode='test'):
        """
        :param rootpth: root directory, forwarded to ``RSDataset``
        :param des_size: target size, passed to ``cv2.resize`` as ``dsize``
        :param mode: name of the sub-directory of ``rootpth`` to read;
                     must be a mode ``RSDataset`` accepts
        """
        # Forward the arguments. A bare super().__init__() would make the
        # parent fall back to its own defaults, so a custom rootpth or
        # des_size would be ignored and the train directory scanned instead.
        super().__init__(rootpth=rootpth, des_size=des_size, mode=mode)

        # List the files again: the parent leaves out one hard-coded file
        # name, while here every file is kept.
        self.file_names = []
        for root, _dirs, names in os.walk(osp.join(rootpth, mode)):
            for name in names:
                self.file_names.append(osp.join(root, name))

    def __getitem__(self, idx):
        """
        :param idx: index into ``self.file_names``
        :return: (normalised image tensor, image id string)
        """
        self.name = self.file_names[idx]
        img = cv2.resize(self.data_loader(self.name), dsize=self.des_size,
                         interpolation=cv2.INTER_LINEAR)
        image_id = self.name.split(self.split_char)[-1].split('.')[0]
        return self.to_tensor(img), image_id
210
211
212
if __name__ == '__main__':
    # Smoke run: load every sample of the default test dataset and print
    # the file path followed by the returned image id.
    dataset = RSDataset_test()
    for index in tqdm(range(len(dataset))):
        _, image_id = dataset[index]
        print(dataset.name)
        print(image_id)

    # Disabled snippet that averages calculateMeanStd() over the dataset
    # (presumably how the normalisation constants were estimated - confirm):
    # mean, std = 0, 0
    # for i in tqdm(range(len(data))):
    #     u, d = data.calculateMeanStd(i)
    #     u /= len(data)
    #     d /= len(data)
    #     mean += u
    #     std += d
    #
    # print(mean, std)