Diff of /features/slic.py [000000] .. [f77492]

Switch to unified view

a b/features/slic.py
1
import math
2
from skimage import io, color
3
import numpy as np
4
from tqdm import trange
5
6
7
class Cluster(object):
8
    cluster_index = 1
9
10
    def __init__(self, h, w, l=0, a=0, b=0):
11
        self.update(h, w, l, a, b)
12
        self.pixels = []
13
        self.no = self.cluster_index
14
        Cluster.cluster_index += 1
15
16
    def update(self, h, w, l, a, b):
17
        self.h = h
18
        self.w = w
19
        self.l = l
20
        self.a = a
21
        self.b = b
22
23
    def __str__(self):
24
        return "{},{}:{} {} {} ".format(self.h, self.w, self.l, self.a, self.b)
25
26
    def __repr__(self):
27
        return self.__str__()
28
29
30
class SLICProcessor(object):
31
    @staticmethod
32
    def open_image(path):
33
        """
34
        Return:
35
            3D array, row col [LAB]
36
        """
37
        rgb = io.imread(path)
38
        lab_arr = color.rgb2lab(rgb)
39
        print(lab_arr.shape)
40
        return lab_arr
41
42
    @staticmethod
43
    def save_lab_image(path, lab_arr):
44
        """
45
        Convert the array to RBG, then save the image
46
        :param path:
47
        :param lab_arr:
48
        :return:
49
        """
50
        rgb_arr = color.lab2rgb(lab_arr)
51
        io.imsave(path, rgb_arr)
52
53
    def make_cluster(self, h, w):
54
        return Cluster(h, w,
55
                       self.data[h][w][0],
56
                       self.data[h][w][1],
57
                       self.data[h][w][2])
58
59
    def __init__(self, img, K, M):
60
        self.K = K
61
        self.M = M
62
63
        self.data = img
64
        self.image_height = self.data.shape[0]
65
        self.image_width = self.data.shape[1]
66
        self.N = self.image_height * self.image_width
67
        self.S = int(math.sqrt(self.N / self.K))
68
69
        self.clusters = []
70
        self.label = {}
71
        self.dis = np.full((self.image_height, self.image_width), np.inf)
72
73
    def init_clusters(self):
74
        h = self.S // 2
75
        w = self.S // 2
76
        while h < self.image_height:
77
            while w < self.image_width:
78
                self.clusters.append(self.make_cluster(h, w))
79
                w += self.S
80
            w = self.S // 2
81
            h += self.S
82
83
    def get_gradient(self, h, w):
84
        if w + 1 >= self.image_width:
85
            w = self.image_width - 2
86
        if h + 1 >= self.image_height:
87
            h = self.image_height - 2
88
89
        gradient = self.data[w + 1][h + 1][0] - self.data[w][h][0] + \
90
                   self.data[w + 1][h + 1][1] - self.data[w][h][1] + \
91
                   self.data[w + 1][h + 1][2] - self.data[w][h][2]
92
        return gradient
93
94
    def move_clusters(self):
95
        for cluster in self.clusters:
96
            cluster_gradient = self.get_gradient(cluster.h, cluster.w)
97
            for dh in range(-1, 2):
98
                for dw in range(-1, 2):
99
                    _h = cluster.h + dh
100
                    _w = cluster.w + dw
101
                    new_gradient = self.get_gradient(_h, _w)
102
                    if new_gradient < cluster_gradient:
103
                        cluster.update(_h, _w, self.data[_h][_w][0], self.data[_h][_w][1], self.data[_h][_w][2])
104
                        cluster_gradient = new_gradient
105
106
    def assignment(self):
107
        for cluster in self.clusters:
108
            for h in range(cluster.h - 2 * self.S, cluster.h + 2 * self.S):
109
                if h < 0 or h >= self.image_height: continue
110
                for w in range(cluster.w - 2 * self.S, cluster.w + 2 * self.S):
111
                    if w < 0 or w >= self.image_width: continue
112
                    L, A, B = self.data[h][w]
113
                    Dc = math.sqrt(
114
                        math.pow(L - cluster.l, 2) +
115
                        math.pow(A - cluster.a, 2) +
116
                        math.pow(B - cluster.b, 2))
117
                    Ds = math.sqrt(
118
                        math.pow(h - cluster.h, 2) +
119
                        math.pow(w - cluster.w, 2))
120
                    D = math.sqrt(math.pow(Dc // self.M, 2) + math.pow(Ds // self.S, 2))
121
                    if D < self.dis[h][w]:
122
                        if (h, w) not in self.label:
123
                            self.label[(h, w)] = cluster
124
                            cluster.pixels.append((h, w))
125
                        else:
126
                            self.label[(h, w)].pixels.remove((h, w))
127
                            self.label[(h, w)] = cluster
128
                            cluster.pixels.append((h, w))
129
                        self.dis[h][w] = D
130
131
    def update_cluster(self):
132
        for cluster in self.clusters:
133
            sum_h = sum_w = number = 0
134
            for p in cluster.pixels:
135
                sum_h += p[0]
136
                sum_w += p[1]
137
                number += 1
138
                _h = sum_h // number
139
                _w = sum_w // number
140
                cluster.update(_h, _w, self.data[_h][_w][0], self.data[_h][_w][1], self.data[_h][_w][2])
141
142
    def save_current_image(self, name):
143
        image_arr = np.copy(self.data)
144
        for cluster in self.clusters:
145
            for p in cluster.pixels:
146
                image_arr[p[0]][p[1]][0] = cluster.l
147
                image_arr[p[0]][p[1]][1] = cluster.a
148
                image_arr[p[0]][p[1]][2] = cluster.b
149
            '''
150
            image_arr[cluster.h][cluster.w][0] = 0
151
            image_arr[cluster.h][cluster.w][1] = 0
152
            image_arr[cluster.h][cluster.w][2] = 0
153
            '''
154
155
        # self.save_lab_image(name, image_arr)
156
        return image_arr
157
158
159
    def iterate_10times(self, filename = 'temp'):
160
        self.init_clusters()
161
        self.move_clusters()
162
        temp_img = []
163
        for i in range(5):
164
            self.assignment()
165
            self.update_cluster()
166
            name = filename[:len(filename) - 4] + 'M{m}_K{k}_loop{loop}.png'.format(loop=i, m=self.M, k=self.K)
167
            temp_img.append(self.save_current_image(name))
168
        # print(temp_img[-1].shape)
169
        return temp_img[-1]
170
171
172
173
# if __name__ == '__main__':
174
#     filename = 'LIDC-IDRI-0087_3000640_1.jpeg'
175
    # p = SLICProcessor(filename, 2048, 3)
176
    # p.iterate_10times(filename)
177
    # p = SLICProcessor(filename, 80, 10)
178
    # p.iterate_10times(filename)
179
    # p = SLICProcessor(filename, 500, 40)
180
    # p.iterate_10times()
181
    # p = SLICProcessor(filename, 1024, 3)
182
    # p.iterate_10times(filename)
183
    # p = SLICProcessor(filename, 200, 5)
184
    # p.iterate_10times()
185
    # p = SLICProcessor(filename, 300, 5)
186
    # p.iterate_10times()
187
    # p = SLICProcessor(filename, 500, 5)
188
    # p.iterate_10times()
189
190
    # import os
191
    
192
    # path = 'res/'
193
    # filelist = os.listdir(path)
194
195
    # for onefile in filelist:
196
    #     filename = path + onefile
197
    #     p = SLICProcessor(filename, 1000, 5)
198
    #     p.iterate_10times(filename)