Switch to side-by-side view

--- a
+++ b/datasets/dataset_segmentation.py
@@ -0,0 +1,116 @@
+import os
+import re
+import numpy as np
+from PIL import Image as PILImage
+import torch
+import torch.nn.functional as F
+from torch.utils import data as data
+from torchvision import transforms as transforms
+import matplotlib.pyplot as plt
+from skimage.measure import label as method_label
+from skimage.measure import regionprops
+
+
+# dataset for GGO and C segmentation
+class CovidCTData(data.Dataset):
+
+    def __init__(self, **kwargs):
+        self.mask_type = kwargs['mask_type']
+        self.ignore_ = kwargs['ignore_small']
+        # ignore small areas?
+        if self.ignore_:
+           self.area_th = 100
+        else:
+           self.area_th = 1
+        self.stage = kwargs['stage']
+        # this returns the path to imgs dir
+        self.data = kwargs['data']
+        # this returns the path to
+        self.gt = kwargs['gt']
+        # IMPORTANT: the order of images and masks must be the same
+        self.sorted_data = sorted(os.listdir(self.data))
+        self.sorted_gt = sorted(os.listdir(self.gt))
+        self.fname = None
+        self.img_fname = None
+
+    # this method normalizes the image and converts it to Pytorch tensor
+    # Here we use pytorch transforms functionality, and Compose them together,
+    def transform_img(self, img):
+        # Faster R-CNN does the normalization
+        t_ = transforms.Compose([
+            # transforms.ToPILImage(),
+            # transforms.Resize(img_size),
+            transforms.ToTensor(),
+        ])
+        img = t_(img)
+        return img
+
+    # inputs: box coords (min_row, min_col, max_row, max_col)
+    # array HxW from whic to extract a single object's mask
+    # each isolated mask should have a different label, lab>0
+    # masks are binary uint8 type
+    def extract_single_mask(self, mask, lab):
+        _mask = np.zeros(mask.shape, dtype=np.uint8)
+        area = mask == lab
+        _mask[area] = 1
+        return _mask
+
+    def load_img(self, idx):
+        im = PILImage.open(os.path.join(self.data, self.sorted_data[idx]))
+        self.img_fname = os.path.join(self.data, self.sorted_data[idx])
+        im = self.transform_img(im)
+        return im
+
+    def load_labels_covid_ctscan_data(self, idx):
+        list_of_bboxes = []
+        labels = []
+        list_of_masks = []
+        # load bbox
+        self.fname = os.path.join(self.gt, self.sorted_gt[idx])
+        # extract bboxes from the mask
+        mask = np.array(PILImage.open(self.fname))
+        # only GGO: merge C and background
+        # or merge GGO and C into a single mask
+        # or keep separate masks
+        if self.mask_type == "ggo":
+           mask[mask==3] = 0
+        elif self.mask_type == "merge":
+           mask[mask==3] = 2
+        # array  (NUM_CLASS_IN_IMNG, H,W) without bgr+lungs class (merge Class 0 and 1)
+        # THIS IS IMPORTANT! CAN TRIGGER CUDA ERROR
+        mask_classes = mask == np.unique(mask)[:, None, None][2:]
+        # extract bounding boxes and masks for each object
+        for _idx, m in enumerate(mask_classes):
+            lab_mask = method_label(m)
+            regions = regionprops(lab_mask)
+            for _i, r in enumerate(regions):
+                # get rid of really small ones:
+                if r.area > self.area_th:
+                    box_coords = (r.bbox[1], r.bbox[0], r.bbox[3], r.bbox[2])
+                    list_of_bboxes.append(box_coords)
+                    labels.append(_idx + 1)
+                    # create a mask for one object, append to the list of masks
+                    mask_obj = self.extract_single_mask(lab_mask, r.label)
+                    list_of_masks.append(mask_obj)
+        # create labels for Mask R-CNN
+        # DO NOT CHANGE THESE DATATYPES!
+        lab = {}
+        list_of_bboxes = torch.as_tensor(list_of_bboxes, dtype=torch.float)
+        labels = torch.tensor(labels, dtype=torch.int64)
+        masks = torch.tensor(list_of_masks, dtype=torch.uint8)
+        lab['labels'] = labels
+        lab['boxes'] = list_of_bboxes
+        lab['masks'] = masks
+        lab['fname'] = self.fname
+        lab['img_name'] = self.img_fname
+        return lab
+
+    # 'magic' method: size of the dataset
+    def __len__(self):
+        return len(os.listdir(self.data))
+
+    # return one datapoint
+    def __getitem__(self, idx):
+        X = self.load_img(idx)
+        y = self.load_labels_covid_ctscan_data(idx)
+        return X, y