Diff of /rs_dataset.py [000000] .. [5ba3a6]

--- a
+++ b/rs_dataset.py
@@ -0,0 +1,227 @@
+# -*- coding: utf-8 -*-
+"""
+@File    : rs_dataset.py
+@Time    : 2019/6/22 10:57
+@Author  : Parker
+@Email   : now_cherish@163.com
+@Software: PyCharm
+@Des     : RSNA intracranial hemorrhage dataset
+"""
+
+import csv
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+import pydicom
+import os.path as osp
+import os
+import numpy as np
+import cv2
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+from skimage.morphology import remove_small_objects
+from skimage.measure import label, regionprops
+from skimage.filters import threshold_otsu
+
+def data_understanding():
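+    """
+    Print, for every observed combination of the six binary labels (encoded
+    as a bit string), how many images carry that combination. Per-class
+    positive counts are also accumulated in ss (not printed).
+    """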
+    labels = prepare_label()
+    s, ss = {}, {}
+    for key, one in tqdm(labels.items()):
+        # encode the six binary labels as one integer bit pattern
+        lb = int("".join(map(str, one)), 2)
+        s.setdefault(lb, []).append(key)
+
+    # per-class positive counts
+    for one in labels.values():
+        for idx, t in enumerate(one):
+            ss[idx] = ss.get(idx, 0) + (1 if t == 1 else 0)
+    for one in s.keys():
+        print(bin(one)[2:].zfill(6), len(s[one]))
+
+
+def prepare_label():
+    labels = ["epidural", "intraparenchymal", "intraventricular",
+              "subarachnoid", "subdural", "any"]
+    label_ranks = {name: i for i, name in enumerate(labels)}
+    all_true_labels = {}
+
+    with open('/media/tiger/zzr/rsna/stage_1_train.csv', 'r') as fp:
+        csv_reader = csv.reader(fp, delimiter=',')
+        next(csv_reader, None)
+        print('processing data ...')
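+        # each row looks like "ID_<image-id>_<subtype>,<0 or 1>"; the six
+        # subtype rows belonging to one image are merged into a single
+        # 6-element label vector keyed by "ID_<image-id>"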
+        for row in tqdm(csv_reader):
+            img_id = "_".join(row[0].split('_')[:2])
+            label_id = label_ranks[row[0].split('_')[2]]
+            if img_id not in all_true_labels:
+                all_true_labels[img_id] = [0] * 6
+            all_true_labels[img_id][label_id] = int(row[1])
+
+    return all_true_labels
+
+
+class RSDataset(Dataset):
+    def __init__(self, rootpth='/media/tiger/zzr/rsna', des_size=(512, 512), mode='train'):
+        """
+        :param rootpth: 根目录
+        :param re_size: 数据同一resize到这个尺寸再后处理
+        :param crop_size: 剪切
+        :param erase: 遮罩比例
+        :param mode: train/val/test
+        """
+        self.root_path = rootpth
+        self.des_size = des_size
+        self.mode = mode
+        self.name = None
+
+        # build the label lookup
+        assert mode in ('train', 'val', 'test')
+        labels = ["epidural", "intraparenchymal", "intraventricular",
+                  "subarachnoid", "subdural", "any"]
+        self.label_ranks = {name: i for i, name in enumerate(labels)}
+        self.labels = self.prepare_label()
+
+        # collect the DICOM file names
+        self.file_names = []
+        for root, dirs, names in os.walk(osp.join(rootpth, mode)):
+            for name in names:
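+                # this training-set DICOM appears to be corrupt and cannot be
+                # decoded, so it is skipped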
+                if name == 'ID_6431af929.dcm':
+                    continue
+                self.file_names.append(osp.join(root, name))
+
+        # determine the path separator used by the filesystem
+        self.split_char = '\\' if '\\' in self.file_names[0] else '/'
+
+        # ToTensor + Normalize with the precomputed dataset mean / std
+        # (cf. calculateMeanStd and the commented-out code in __main__)
+        self.to_tensor = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize([32.98408291578699], [33.70147134726827])
+        ])
+
+    def data_loader(self, fname):
+        """
+        load data
+        :param fname:
+        :return:
+        """
+        ds = pydicom.dcmread(fname)
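+        # WindowCenter / WindowWidth may be stored as a multi-valued field or
+        # as a single value, depending on the file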
+        try:
+            windowCenter = int(ds.WindowCenter[0])
+            windowWidth = int(ds.WindowWidth[0])
+        except Exception:
+            windowCenter = int(ds.WindowCenter)
+            windowWidth = int(ds.WindowWidth)
+        intercept = ds.RescaleIntercept
+        slope = ds.RescaleSlope
+        data = ds.pixel_array
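+        # convert raw pixels with the rescale slope/intercept, then clip to
+        # the stored window range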
+        data = np.clip(data * slope + intercept, windowCenter - windowWidth / 2, windowCenter + windowWidth / 2).astype(np.float32)
+        data = self.preprocess(data)
+        return data
+
+    def preprocess(self, data):
+        """
+        otsu threshold
+        :param data:
+        :return:
+        """
+        try:
+            thres = threshold_otsu(data)
+        except Exception:
+            # a constant image has no Otsu threshold; fall back to the minimum
+            thres = np.min(data)
+
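+        # binarize, drop small specks, and crop to the bounding box of the
+        # largest connected component (the head region)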
+        data1 = data > thres
+        data1 = remove_small_objects(data1)
+        label_data = label(data1)
+        props = regionprops(label_data)
+        area = 0
+        bbox = (0, 0, np.shape(data)[0], np.shape(data)[1])
+        for prop in props:
+            if prop.area > area:
+                area = prop.area
+                bbox = prop.bbox
+
+        data1 = data[bbox[0]:bbox[2] + 1, bbox[1]:bbox[3] + 1]
+
+        return data1
+
+    def prepare_label(self):
+        all_true_labels = {}
+        with open(osp.join(self.root_path, 'stage_1_train.csv'), 'r') as fp:
+            csv_reader = csv.reader(fp, delimiter=',')
+            next(csv_reader, None)
+            for row in tqdm(csv_reader):
+                img_id = "_".join(row[0].split('_')[:2])
+                label_id = self.label_ranks[row[0].split('_')[2]]
+                if img_id not in all_true_labels:
+                    all_true_labels[img_id] = [0] * 6
+                all_true_labels[img_id][label_id] = float(row[1])
+
+        return all_true_labels
+
+    def __getitem__(self, idx):
+        self.name = self.file_names[idx]
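+        # the filename stem ("ID_xxx") is the key into the label dictionary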
+        category = self.labels[self.name.split(self.split_char)[-1].split('.')[0]]
+        img = cv2.resize(self.data_loader(self.name), dsize=self.des_size, interpolation=cv2.INTER_LINEAR)
+        # plt.imshow(img)
+        # plt.show()
+        return self.to_tensor(img), torch.tensor(category)
+
+    def __len__(self):
+        return len(self.file_names)
+
+    def calculateMeanStd(self, idx):
+        """
+
+        :param idx:
+        :return:
+        """
+        self.name = self.file_names[idx]
+        img = self.data_loader(self.name)
+
+        return np.mean(img), np.std(img)
+
+
+class RSDataset_test(RSDataset):
+    def __init__(self, rootpth='/media/tiger/zzr/rsna', des_size=(512, 512), mode='test'):
+        # forward the arguments, otherwise rootpth and des_size are silently ignored
+        super().__init__(rootpth=rootpth, des_size=des_size, mode=mode)
+        # collect the test file names (unlike the training set, nothing is skipped)
+        self.file_names = []
+        for root, dirs, names in os.walk(osp.join(rootpth, mode)):
+            for name in names:
+                self.file_names.append(osp.join(root, name))
+
+    def __getitem__(self, idx):
+        self.name = self.file_names[idx]
+        img = cv2.resize(self.data_loader(self.name), dsize=self.des_size, interpolation=cv2.INTER_LINEAR)
+        return self.to_tensor(img), self.name.split(self.split_char)[-1].split('.')[0]
+
+    def __len__(self):
+        return len(self.file_names)
+
+
+if __name__ == '__main__':
+    data = RSDataset_test()
+    for i in tqdm(range(len(data))):
+        a, b = data[i]
+        print(data.name)
+        print(b)
+
+    # mean, std = 0, 0
+    # for i in tqdm(range(len(data))):
+    #     u, d = data.calculateMeanStd(i)
+    #     u /= len(data)
+    #     d /= len(data)
+    #     mean += u
+    #     std += d
+    #
+    # print(mean, std)
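+
+    # Sketch (not part of the original code): for training, the dataset would
+    # typically be consumed through a torch DataLoader; the hyperparameters
+    # below are illustrative only.
+    #
+    # from torch.utils.data import DataLoader
+    #
+    # train_set = RSDataset(mode='train', des_size=(512, 512))
+    # train_loader = DataLoader(train_set, batch_size=8, shuffle=True,
+    #                           num_workers=4, pin_memory=True)
+    # for imgs, targets in train_loader:
+    #     # imgs: (B, 1, 512, 512) float tensor, targets: (B, 6) float tensor
+    #     pass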