a b/FastRCNN/utils/DefectDataset.py
1
import numpy as np
2
import os
3
from skimage import exposure, filters
4
import chainer
5
from chainercv import utils
6
from chainercv import transforms
7
import warnings
8
9
# Default dataset directory used by DetectionDataset when data_dir == 'auto'.
root = './Data/'
10
# Default dataset directory used by MultiDetectionDataset when
# data_dir == 'auto'.
root2 = './Data/'
11
12
class DetectionDataset(chainer.dataset.DatasetMixin):
    """Base class for the single-class defect detection dataset.

    The dataset directory is expected to contain:

    * ``{split}images.txt`` -- one image file name per line,
    * ``images/`` -- the image files named in the index,
    * ``bounding_boxes/`` -- one ``<image name without extension>.txt`` per
      image, with one bounding box (whitespace-separated numbers) per line.

    Args:
        data_dir (str): Dataset directory. ``'auto'`` selects the
            module-level default ``root``.
        split (str): Prefix of the index file name
            (``'{split}images.txt'``).
        img_size (int): Target height and width used when ``resize`` is
            true.
        resize (bool): If true, every image whose height or width differs
            from ``img_size`` is resized to ``img_size x img_size`` and its
            bounding boxes are rescaled to match.
    """

    def __init__(self, data_dir='auto', split='', img_size=1024, resize=False):
        if data_dir == 'auto':
            data_dir = root
        self.data_dir = data_dir
        self.img_size = img_size
        self.resize = resize

        images_file = os.path.join(self.data_dir, '{}images.txt'.format(split))
        # Close the index file deterministically and skip blank lines so a
        # trailing newline does not create an empty, unreadable example.
        with open(images_file) as index:
            self.images = [line.strip() for line in index if line.strip()]

    def obtain_image_name(self, i):
        """Returns the file name of the i-th image as listed in the index."""
        return self.images[i]

    def __len__(self):
        return len(self.images)

    def _read_image(self, i):
        """Reads the i-th image and preprocesses its second and third channels.

        Channel 0 is left as read. Channel 1 is replaced by its adaptive
        histogram equalisation and channel 2 by its Gaussian-filtered
        version; both are rescaled to the range ``[0, 255]``.
        """
        img = utils.read_image(
            os.path.join(self.data_dir, 'images', self.images[i]),
            color=True)
        # The skimage calls below may emit warnings on every image; they are
        # silenced here so they do not flood the output.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            img[1, :, :] = exposure.rescale_intensity(
                exposure.equalize_adapthist(
                    exposure.rescale_intensity(img[1, :, :])),
                out_range=(0, 255))
            img[2, :, :] = exposure.rescale_intensity(
                filters.gaussian(
                    exposure.rescale_intensity(img[2, :, :])),
                out_range=(0, 255))
        return img

    def get_example(self, i):
        """Returns the i-th example.

        Args:
            i (int): The index of the example.

        Returns:
            tuple of an image, its bounding boxes and their labels.
            The image is in CHW format and its color channel is ordered in
            RGB. The bounding boxes are a ``float32`` array with one row per
            box (``(0, 4)`` when the annotation file is empty). The labels
            are an ``int32`` array of zeros with one entry per box, since
            this dataset has a single class.
        """
        img = self._read_image(i)

        # splitext handles extensions of any length ('.jpeg', '.tiff', ...),
        # unlike stripping a fixed four characters.
        stem = os.path.splitext(self.images[i])[0]
        bbs_file = os.path.join(self.data_dir, 'bounding_boxes', stem + '.txt')

        with open(bbs_file) as annotations:
            rows = [line.split() for line in annotations if line.strip()]
        if rows:
            bbs = np.array(rows).astype(np.float32)
        else:
            # An image without any box is valid: keep the (m, 4) layout.
            bbs = np.zeros((0, 4), dtype=np.float32)
        label = np.zeros(bbs.shape[0], dtype=np.int32)

        _, H, W = img.shape
        if self.resize and (H != self.img_size or W != self.img_size):
            img = transforms.resize(img, (self.img_size, self.img_size))
            bbs = transforms.resize_bbox(
                bbs, (H, W), (self.img_size, self.img_size))

        return img, bbs, label
70
71
class MultiDetectionDataset(chainer.dataset.DatasetMixin):
    """Base class for the multi-class defect detection dataset.

    The dataset directory is expected to contain:

    * ``{split}images.txt`` -- one image file name per line,
    * ``images/`` -- the image files named in the index,
    * ``bounding_boxes/`` -- one ``<image name without extension>.txt`` per
      image, with one object per line: the class label in the first column
      followed by the four bounding-box numbers.

    Args:
        data_dir (str): Dataset directory. ``'auto'`` selects the
            module-level default ``root2``.
        split (str): Prefix of the index file name
            (``'{split}images.txt'``).
        img_size (int): Target height and width used when ``resize`` is
            true.
        resize (bool): If true, every image whose height or width differs
            from ``img_size`` is resized to ``img_size x img_size`` and its
            bounding boxes are rescaled to match.
    """

    def __init__(self, data_dir='auto', split='', img_size=1024, resize=False):
        if data_dir == 'auto':
            data_dir = root2
        self.data_dir = data_dir
        self.img_size = img_size
        self.resize = resize

        images_file = os.path.join(self.data_dir, '{}images.txt'.format(split))
        # Close the index file deterministically and skip blank lines so a
        # trailing newline does not create an empty, unreadable example.
        with open(images_file) as index:
            self.images = [line.strip() for line in index if line.strip()]

    def __len__(self):
        return len(self.images)

    def _read_image(self, i):
        """Reads the i-th image and preprocesses its second and third channels.

        Channel 0 is left as read. Channel 1 is replaced by its adaptive
        histogram equalisation and channel 2 by its Gaussian-filtered
        version; both are rescaled to the range ``[0, 255]``.
        """
        img = utils.read_image(
            os.path.join(self.data_dir, 'images', self.images[i]),
            color=True)
        # The skimage calls below may emit warnings on every image; they are
        # silenced here so they do not flood the output.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            img[1, :, :] = exposure.rescale_intensity(
                exposure.equalize_adapthist(
                    exposure.rescale_intensity(img[1, :, :])),
                out_range=(0, 255))
            img[2, :, :] = exposure.rescale_intensity(
                filters.gaussian(
                    exposure.rescale_intensity(img[2, :, :])),
                out_range=(0, 255))
        return img

    def get_example(self, i):
        """Returns the i-th example.

        Args:
            i (int): The index of the example.

        Returns:
            tuple of an image, its bounding boxes and their labels.
            The image is in CHW format and its color channel is ordered in
            RGB. The bounding boxes are a ``float32`` array of shape
            ``(m, 4)`` and the labels an ``int32`` array of shape ``(m,)``,
            where ``m`` is the number of lines in the annotation file
            (``m`` may be 0 or 1).
        """
        img = self._read_image(i)

        # splitext handles extensions of any length ('.jpeg', '.tiff', ...),
        # unlike stripping a fixed four characters.
        stem = os.path.splitext(self.images[i])[0]
        bbs_file = os.path.join(self.data_dir, 'bounding_boxes', stem + '.txt')

        # ndmin=2 keeps a single-line file as one row instead of collapsing
        # it to a 1-D vector, so one-object and many-object images go through
        # the same column slicing below.
        label_bbs = np.loadtxt(bbs_file, dtype=np.float32, ndmin=2)
        if label_bbs.size == 0:
            # An image without any object is valid: keep the (m, 5) layout.
            label_bbs = np.zeros((0, 5), dtype=np.float32)
        label = label_bbs[:, 0].astype(np.int32)
        bbs = label_bbs[:, 1:5]

        _, H, W = img.shape
        if self.resize and (H != self.img_size or W != self.img_size):
            img = transforms.resize(img, (self.img_size, self.img_size))
            bbs = transforms.resize_bbox(
                bbs, (H, W), (self.img_size, self.img_size))

        return img, bbs, label