|
a |
|
b/rs_dataset.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
""" |
|
|
3 |
@File : rs_dataset.py |
|
|
4 |
@Time : 2019/6/22 10:57 |
|
|
5 |
@Author : Parker |
|
|
6 |
@Email : now_cherish@163.com |
|
|
7 |
@Software: PyCharm |
|
|
8 |
@Des : data set |
|
|
9 |
""" |
|
|
10 |
|
|
|
11 |
import csv |
|
|
12 |
import torch |
|
|
13 |
from torch.utils.data import Dataset |
|
|
14 |
import torchvision.transforms as transforms |
|
|
15 |
import pydicom |
|
|
16 |
import os.path as osp |
|
|
17 |
import os |
|
|
18 |
from PIL import Image |
|
|
19 |
import numpy as np |
|
|
20 |
import random |
|
|
21 |
import cv2 |
|
|
22 |
from tqdm import tqdm |
|
|
23 |
import matplotlib.pyplot as plt |
|
|
24 |
import time |
|
|
25 |
from skimage.morphology import remove_small_holes, remove_small_objects |
|
|
26 |
from skimage.measure import label, regionprops |
|
|
27 |
from skimage.filters import threshold_otsu |
|
|
28 |
|
|
|
29 |
def data_understanding():
    """
    Print a quick class-distribution summary of the training labels: for every
    observed 6-bit label combination, print its bitmask and how many images
    carry it.
    """
    labels = prepare_label()
    combos, per_label = {}, {}
    # iterate (key, value) pairs directly instead of zipping keys() with values()
    for key, one in tqdm(labels.items()):
        # encode the 6 binary flags as a single integer (e.g. [0,1,0,0,0,1] -> 0b010001)
        lb = int("".join(map(str, one)), 2)
        if lb not in combos:
            combos[lb] = []
        combos[lb].append(key)

    # count positives per label position
    for one in labels.values():
        for idx, t in enumerate(one):
            if idx not in per_label:
                per_label[idx] = 0
            if t == 1:
                per_label[idx] += 1
    # NOTE(review): per_label is computed but never reported (same as the
    # original code) — print it here if the per-class counts are wanted.
    for combo in combos:
        print(bin(combo)[2:].zfill(6), len(combos[combo]))
|
|
46 |
|
|
|
47 |
|
|
|
48 |
def prepare_label(csv_path='/media/tiger/zzr/rsna/stage_1_train.csv'):
    """
    Parse the RSNA stage-1 training csv into a per-image label vector.

    :param csv_path: path of the csv file. Rows look like
        "ID_<image>_<subtype>,<0 or 1>". The previously hard-coded path is now
        a defaulted parameter (backward compatible).
    :return: {image_id: [6 ints]} with columns ordered as `labels` below
    """
    labels = ["epidural", "intraparenchymal", "intraventricular",
              "subarachnoid", "subdural", "any"]
    label_ranks = {name: i for i, name in enumerate(labels)}
    all_true_labels = {}

    with open(csv_path, 'r') as fp:
        csv_reader = csv.reader(fp, delimiter=',')
        next(csv_reader, None)  # skip the header row
        print('processing data ...')
        for row in tqdm(csv_reader):
            # "ID_abc123_subdural" -> image id "ID_abc123", subtype "subdural"
            img_id = "_".join(row[0].split('_')[:2])
            label_id = label_ranks[row[0].split('_')[2]]
            if img_id not in all_true_labels:
                all_true_labels[img_id] = [0] * 6
            all_true_labels[img_id][label_id] = int(row[1])

    return all_true_labels
|
|
68 |
|
|
|
69 |
|
|
|
70 |
class RSDataset(Dataset):
    """RSNA intracranial-hemorrhage dataset: DICOM slices + 6 binary labels."""

    def __init__(self, rootpth='/media/tiger/zzr/rsna', des_size=(512, 512), mode='train'):
        """
        :param rootpth: dataset root directory (contains the csv and the
            train/val/test sub-directories)
        :param des_size: every image is resized to this (width, height)
        :param mode: train/val/test — selects which sub-directory is walked
        """
        self.root_path = rootpth
        self.des_size = des_size
        self.mode = mode
        self.name = None  # path of the most recently served file (set in __getitem__)

        # build label-name -> column-index mapping
        assert (mode == 'train' or mode == 'val' or mode == 'test')
        labels = ["epidural", "intraparenchymal", "intraventricular",
                  "subarachnoid", "subdural", "any"]
        self.label_ranks = {}
        for i in range(len(labels)):
            self.label_ranks[labels[i]] = i
        self.labels = self.prepare_label()

        # collect file names
        self.file_names = []
        for root, dirs, names in os.walk(osp.join(rootpth, mode)):
            for name in names:
                # known-corrupted file — skip it
                if name == 'ID_6431af929.dcm':
                    continue
                self.file_names.append(osp.join(root, name))

        # determine which path separator the collected names actually use
        self.split_char = '\\' if '\\' in self.file_names[0] else '/'

        # to-tensor conversion; mean/std were precomputed over the training set
        self.to_tensor = transforms.Compose([  # 32.98408291578699 33.70147134726827
            transforms.ToTensor(),
            transforms.Normalize(32.98408291578699, 33.70147134726827)
        ])

    def data_loader(self, fname):
        """
        Load one DICOM file, apply rescale slope/intercept, clip to the stored
        window, then crop to the largest foreground region.

        :param fname: path of the .dcm file
        :return: cropped float32 numpy array
        """
        ds = pydicom.dcmread(fname)
        try:
            # WindowCenter/WindowWidth may be multi-valued; take the first entry
            windowCenter = int(ds.WindowCenter[0])
            windowWidth = int(ds.WindowWidth[0])
        except (TypeError, ValueError, IndexError):
            # single-valued tags are not subscriptable — read them directly
            windowCenter = int(ds.WindowCenter)
            windowWidth = int(ds.WindowWidth)
        intercept = ds.RescaleIntercept
        slope = ds.RescaleSlope
        data = ds.pixel_array
        # convert stored values to HU-like units, then clip to the display window
        data = np.clip(data * slope + intercept,
                       windowCenter - windowWidth / 2,
                       windowCenter + windowWidth / 2).astype(np.float32)
        data = self.preprocess(data)
        return data

    def preprocess(self, data):
        """
        Crop the image to the bounding box of its largest connected foreground
        component, where "foreground" means above the Otsu threshold.

        :param data: 2-D float32 image
        :return: cropped 2-D image
        """
        try:
            thres = threshold_otsu(data)
        except Exception:
            # Otsu fails on (near-)constant images; fall back to keeping everything
            thres = np.min(data)

        mask = data > thres
        mask = remove_small_objects(mask)
        labelled = label(mask)
        props = regionprops(labelled)
        area = 0
        # default: whole image, in case no region survives
        bbox = (0, 0, np.shape(data)[0], np.shape(data)[1])
        for region in props:
            if region.area > area:
                area = region.area
                bbox = region.bbox

        return data[bbox[0]:bbox[2] + 1, bbox[1]:bbox[3] + 1]

    def prepare_label(self):
        """Parse stage_1_train.csv into {image_id: [6 floats]} ordered by self.label_ranks."""
        all_true_labels = {}
        with open(osp.join(self.root_path, 'stage_1_train.csv'), 'r') as fp:
            csv_reader = csv.reader(fp, delimiter=',')
            next(csv_reader, None)  # skip the header row
            for row in tqdm(csv_reader):
                # rows look like "ID_<image>_<subtype>,<value>"
                img_id = "_".join(row[0].split('_')[:2])
                label_id = self.label_ranks[row[0].split('_')[2]]
                if img_id not in all_true_labels:
                    all_true_labels[img_id] = [0] * 6
                all_true_labels[img_id][label_id] = float(row[1])

        return all_true_labels

    def __getitem__(self, idx):
        self.name = self.file_names[idx]
        # image id (file name without extension) keys into the label dict
        category = self.labels[self.name.split(self.split_char)[-1].split('.')[0]]
        img = cv2.resize(self.data_loader(self.name), dsize=self.des_size,
                         interpolation=cv2.INTER_LINEAR)
        # plt.imshow(img)
        # plt.show()
        return self.to_tensor(img), torch.tensor(category)

    def __len__(self):
        return len(self.file_names)

    def calculateMeanStd(self, idx):
        """
        Mean/std of one un-resized image — used offline to derive the
        normalisation constants used in __init__.

        :param idx: index into self.file_names
        :return: (mean, std) of the loaded image
        """
        self.name = self.file_names[idx]
        img = self.data_loader(self.name)

        return np.mean(img), np.std(img)
|
|
192 |
|
|
|
193 |
|
|
|
194 |
class RSDataset_test(RSDataset):
    """Test-time variant: items are (tensor, image_id) instead of (tensor, label)."""

    def __init__(self, rootpth='/media/tiger/zzr/rsna', des_size=(512, 512), mode='test'):
        # BUG FIX: super().__init__() was called without arguments, so a
        # caller-supplied rootpth/des_size/mode was silently ignored (labels
        # were read from the default root and images resized to the default
        # size). Forward the arguments explicitly.
        super().__init__(rootpth=rootpth, des_size=des_size, mode=mode)
        # re-walk without skipping any file: at test time every image is scored
        self.file_names = []
        for root, dirs, names in os.walk(osp.join(rootpth, mode)):
            for name in names:
                self.file_names.append(osp.join(root, name))

    def __getitem__(self, idx):
        self.name = self.file_names[idx]
        img = cv2.resize(self.data_loader(self.name), dsize=self.des_size,
                         interpolation=cv2.INTER_LINEAR)
        # second element is the image id (file name without extension)
        return self.to_tensor(img), self.name.split(self.split_char)[-1].split('.')[0]

    def __len__(self):
        return len(self.file_names)
|
|
210 |
|
|
|
211 |
|
|
|
212 |
if __name__ == '__main__':
    # Smoke test: iterate the test split and show each served file/id pair.
    dataset = RSDataset_test()
    for index in tqdm(range(len(dataset))):
        tensor, image_id = dataset[index]
        print(dataset.name)
        print(image_id)

    # mean, std = 0, 0
    # for i in tqdm(range(len(data))):
    #     u, d = data.calculateMeanStd(i)
    #     u /= len(data)
    #     d /= len(data)
    #     mean += u
    #     std += d
    #
    # print(mean, std)