Switch to unified view

a b/datasets/dataset_segmentation.py
1
import os
2
import re
3
import numpy as np
4
from PIL import Image as PILImage
5
import torch
6
import torch.nn.functional as F
7
from torch.utils import data as data
8
from torchvision import transforms as transforms
9
import matplotlib.pyplot as plt
10
from skimage.measure import label as method_label
11
from skimage.measure import regionprops
12
13
14
# dataset for GGO and C segmentation
15
class CovidCTData(data.Dataset):
16
17
    def __init__(self, **kwargs):
18
        self.mask_type = kwargs['mask_type']
19
        self.ignore_ = kwargs['ignore_small']
20
        # ignore small areas?
21
        if self.ignore_:
22
           self.area_th = 100
23
        else:
24
           self.area_th = 1
25
        self.stage = kwargs['stage']
26
        # this returns the path to imgs dir
27
        self.data = kwargs['data']
28
        # this returns the path to
29
        self.gt = kwargs['gt']
30
        # IMPORTANT: the order of images and masks must be the same
31
        self.sorted_data = sorted(os.listdir(self.data))
32
        self.sorted_gt = sorted(os.listdir(self.gt))
33
        self.fname = None
34
        self.img_fname = None
35
36
    # this method normalizes the image and converts it to Pytorch tensor
37
    # Here we use pytorch transforms functionality, and Compose them together,
38
    def transform_img(self, img):
39
        # Faster R-CNN does the normalization
40
        t_ = transforms.Compose([
41
            # transforms.ToPILImage(),
42
            # transforms.Resize(img_size),
43
            transforms.ToTensor(),
44
        ])
45
        img = t_(img)
46
        return img
47
48
    # inputs: box coords (min_row, min_col, max_row, max_col)
49
    # array HxW from whic to extract a single object's mask
50
    # each isolated mask should have a different label, lab>0
51
    # masks are binary uint8 type
52
    def extract_single_mask(self, mask, lab):
53
        _mask = np.zeros(mask.shape, dtype=np.uint8)
54
        area = mask == lab
55
        _mask[area] = 1
56
        return _mask
57
58
    def load_img(self, idx):
59
        im = PILImage.open(os.path.join(self.data, self.sorted_data[idx]))
60
        self.img_fname = os.path.join(self.data, self.sorted_data[idx])
61
        im = self.transform_img(im)
62
        return im
63
64
    def load_labels_covid_ctscan_data(self, idx):
65
        list_of_bboxes = []
66
        labels = []
67
        list_of_masks = []
68
        # load bbox
69
        self.fname = os.path.join(self.gt, self.sorted_gt[idx])
70
        # extract bboxes from the mask
71
        mask = np.array(PILImage.open(self.fname))
72
        # only GGO: merge C and background
73
        # or merge GGO and C into a single mask
74
        # or keep separate masks
75
        if self.mask_type == "ggo":
76
           mask[mask==3] = 0
77
        elif self.mask_type == "merge":
78
           mask[mask==3] = 2
79
        # array  (NUM_CLASS_IN_IMNG, H,W) without bgr+lungs class (merge Class 0 and 1)
80
        # THIS IS IMPORTANT! CAN TRIGGER CUDA ERROR
81
        mask_classes = mask == np.unique(mask)[:, None, None][2:]
82
        # extract bounding boxes and masks for each object
83
        for _idx, m in enumerate(mask_classes):
84
            lab_mask = method_label(m)
85
            regions = regionprops(lab_mask)
86
            for _i, r in enumerate(regions):
87
                # get rid of really small ones:
88
                if r.area > self.area_th:
89
                    box_coords = (r.bbox[1], r.bbox[0], r.bbox[3], r.bbox[2])
90
                    list_of_bboxes.append(box_coords)
91
                    labels.append(_idx + 1)
92
                    # create a mask for one object, append to the list of masks
93
                    mask_obj = self.extract_single_mask(lab_mask, r.label)
94
                    list_of_masks.append(mask_obj)
95
        # create labels for Mask R-CNN
96
        # DO NOT CHANGE THESE DATATYPES!
97
        lab = {}
98
        list_of_bboxes = torch.as_tensor(list_of_bboxes, dtype=torch.float)
99
        labels = torch.tensor(labels, dtype=torch.int64)
100
        masks = torch.tensor(list_of_masks, dtype=torch.uint8)
101
        lab['labels'] = labels
102
        lab['boxes'] = list_of_bboxes
103
        lab['masks'] = masks
104
        lab['fname'] = self.fname
105
        lab['img_name'] = self.img_fname
106
        return lab
107
108
    # 'magic' method: size of the dataset
109
    def __len__(self):
110
        return len(os.listdir(self.data))
111
112
    # return one datapoint
113
    def __getitem__(self, idx):
114
        X = self.load_img(idx)
115
        y = self.load_labels_covid_ctscan_data(idx)
116
        return X, y