|
a |
|
b/datasets/dataset_segmentation.py |
|
|
1 |
import os |
|
|
2 |
import re |
|
|
3 |
import numpy as np |
|
|
4 |
from PIL import Image as PILImage |
|
|
5 |
import torch |
|
|
6 |
import torch.nn.functional as F |
|
|
7 |
from torch.utils import data as data |
|
|
8 |
from torchvision import transforms as transforms |
|
|
9 |
import matplotlib.pyplot as plt |
|
|
10 |
from skimage.measure import label as method_label |
|
|
11 |
from skimage.measure import regionprops |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
# dataset for GGO and C segmentation |
|
|
15 |
class CovidCTData(data.Dataset):
    """Dataset of CT images with per-object targets for Mask R-CNN.

    Each item is a pair ``(image_tensor, target)`` where ``target`` is a dict
    with the keys ``labels``, ``boxes``, ``masks``, ``fname`` and ``img_name``.

    Ground-truth masks are single-channel class maps.  The code treats the
    values 0 and 1 as background (per the original author's note: background
    and lungs are merged), value 2 as the first lesion class ("GGO") and
    value 3 as the second lesion class ("C").
    NOTE(review): this assumes mask pixel values are limited to 0..3 --
    confirm against the dataset; larger values would yield labels > 2.

    Keyword arguments (all required):
        mask_type: ``"ggo"`` keeps only class 2 (class 3 is sent to
            background), ``"merge"`` folds class 3 into class 2; any other
            value keeps classes 2 and 3 as separate labels.
        ignore_small: if truthy, connected regions with an area of 100
            pixels or less are dropped; otherwise only regions with an area
            of 1 pixel or less are dropped.
        stage: stored as ``self.stage``; not used inside this class.
        data: path to the directory with the images.
        gt: path to the directory with the ground-truth masks.

    IMPORTANT: images and masks are paired by their position in the sorted
    directory listings, so both directories must sort into the same order.
    """

    def __init__(self, **kwargs):
        self.mask_type = kwargs['mask_type']
        self.ignore_ = kwargs['ignore_small']
        # Area threshold (in pixels) below which connected regions are skipped.
        self.area_th = 100 if self.ignore_ else 1
        self.stage = kwargs['stage']
        # Path to the images directory.
        self.data = kwargs['data']
        # Path to the ground-truth masks directory.
        self.gt = kwargs['gt']
        # IMPORTANT: the order of images and masks must be the same.
        self.sorted_data = sorted(os.listdir(self.data))
        self.sorted_gt = sorted(os.listdir(self.gt))
        # Paths of the most recently loaded mask / image.
        self.fname = None
        self.img_fname = None

    def transform_img(self, img):
        """Convert an image to a PyTorch tensor via ``transforms.ToTensor``.

        No further normalization is applied here; per the original author's
        note the detection model performs the normalization itself.
        """
        to_tensor = transforms.Compose([
            transforms.ToTensor(),
        ])
        return to_tensor(img)

    def extract_single_mask(self, mask, lab):
        """Return a binary uint8 mask that is 1 where ``mask == lab``.

        ``mask`` is a labelled array in which every isolated object carries a
        distinct label (``lab > 0``); the result has the same shape.
        """
        return (mask == lab).astype(np.uint8)

    def load_img(self, idx):
        """Load image number ``idx`` (sorted order) and return it as a tensor.

        Also records the image path in ``self.img_fname``.
        """
        path = os.path.join(self.data, self.sorted_data[idx])
        # The context manager closes the underlying file once the pixel data
        # has been copied into the tensor.
        with PILImage.open(path) as im:
            self.img_fname = path
            return self.transform_img(im)

    @staticmethod
    def _lesion_classes(mask):
        """List ``(pixel_value, label)`` pairs for lesion classes in ``mask``.

        Values 0 and 1 are background.  The label is derived from the pixel
        value itself (2 -> 1, 3 -> 2) rather than from its rank among the
        values present, so an image that lacks one class still gets the
        correct label for the other.
        """
        return [(int(v), int(v) - 1) for v in np.unique(mask) if v >= 2]

    @staticmethod
    def _pack_target(boxes, labels, masks, mask_shape):
        """Convert per-object lists into tensors of the expected dtypes.

        Returns a dict with ``labels`` (int64, shape ``(N,)``), ``boxes``
        (float, shape ``(N, 4)``) and ``masks`` (uint8, shape ``(N, H, W)``).
        For ``N == 0`` the tensors keep their trailing dimensions, i.e.
        ``(0, 4)`` and ``(0, H, W)``.
        """
        # DO NOT CHANGE THESE DATATYPES!
        if boxes:
            boxes_t = torch.as_tensor(boxes, dtype=torch.float)
        else:
            boxes_t = torch.zeros((0, 4), dtype=torch.float)
        labels_t = torch.tensor(labels, dtype=torch.int64)
        if masks:
            # Stack first: converting a list of arrays element by element is
            # much slower than converting one contiguous array.
            masks_t = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
        else:
            masks_t = torch.zeros((0,) + tuple(mask_shape), dtype=torch.uint8)
        target = {}
        target['labels'] = labels_t
        target['boxes'] = boxes_t
        target['masks'] = masks_t
        return target

    def load_labels_covid_ctscan_data(self, idx):
        """Build the Mask R-CNN target dict for mask number ``idx``.

        Every connected region of every lesion class whose area exceeds
        ``self.area_th`` becomes one object with a box ``(x1, y1, x2, y2)``,
        a class label and a binary mask.

        ``img_name`` is taken from ``self.img_fname``, i.e. it reflects the
        last call to ``load_img``; ``__getitem__`` calls that first.
        """
        list_of_bboxes = []
        labels = []
        list_of_masks = []
        self.fname = os.path.join(self.gt, self.sorted_gt[idx])
        with PILImage.open(self.fname) as mask_img:
            mask = np.array(mask_img)
        # only GGO: merge C and background,
        # or merge GGO and C into a single mask,
        # or keep separate masks.
        if self.mask_type == "ggo":
            mask[mask == 3] = 0
        elif self.mask_type == "merge":
            mask[mask == 3] = 2
        # THIS IS IMPORTANT! A label outside the model's class range can
        # trigger a CUDA error, so labels are tied to the class value.
        for value, class_label in self._lesion_classes(mask):
            lab_mask = method_label(mask == value)
            for r in regionprops(lab_mask):
                # get rid of really small ones
                if r.area > self.area_th:
                    # regionprops bbox is (min_row, min_col, max_row, max_col);
                    # reorder it to (min_col, min_row, max_col, max_row).
                    min_row, min_col, max_row, max_col = r.bbox
                    list_of_bboxes.append((min_col, min_row, max_col, max_row))
                    labels.append(class_label)
                    # one binary mask per object
                    list_of_masks.append(
                        self.extract_single_mask(lab_mask, r.label))
        lab = self._pack_target(
            list_of_bboxes, labels, list_of_masks, mask.shape)
        lab['fname'] = self.fname
        lab['img_name'] = self.img_fname
        return lab

    def __len__(self):
        """Return the number of images listed when the dataset was created."""
        # Use the cached listing so the length always matches what
        # ``load_img`` can index, without re-reading the directory.
        return len(self.sorted_data)

    def __getitem__(self, idx):
        """Return one datapoint: ``(image_tensor, target_dict)``."""
        # The image is loaded first so that ``self.img_fname`` is current
        # when the target dict is assembled.
        X = self.load_img(idx)
        y = self.load_labels_covid_ctscan_data(idx)
        return X, y