|
a |
|
b/features/slic.py |
|
|
1 |
import math |
|
|
2 |
from skimage import io, color |
|
|
3 |
import numpy as np |
|
|
4 |
from tqdm import trange |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
class Cluster(object): |
|
|
8 |
cluster_index = 1 |
|
|
9 |
|
|
|
10 |
def __init__(self, h, w, l=0, a=0, b=0): |
|
|
11 |
self.update(h, w, l, a, b) |
|
|
12 |
self.pixels = [] |
|
|
13 |
self.no = self.cluster_index |
|
|
14 |
Cluster.cluster_index += 1 |
|
|
15 |
|
|
|
16 |
def update(self, h, w, l, a, b): |
|
|
17 |
self.h = h |
|
|
18 |
self.w = w |
|
|
19 |
self.l = l |
|
|
20 |
self.a = a |
|
|
21 |
self.b = b |
|
|
22 |
|
|
|
23 |
def __str__(self): |
|
|
24 |
return "{},{}:{} {} {} ".format(self.h, self.w, self.l, self.a, self.b) |
|
|
25 |
|
|
|
26 |
def __repr__(self): |
|
|
27 |
return self.__str__() |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
class SLICProcessor(object): |
|
|
31 |
@staticmethod |
|
|
32 |
def open_image(path): |
|
|
33 |
""" |
|
|
34 |
Return: |
|
|
35 |
3D array, row col [LAB] |
|
|
36 |
""" |
|
|
37 |
rgb = io.imread(path) |
|
|
38 |
lab_arr = color.rgb2lab(rgb) |
|
|
39 |
print(lab_arr.shape) |
|
|
40 |
return lab_arr |
|
|
41 |
|
|
|
42 |
@staticmethod |
|
|
43 |
def save_lab_image(path, lab_arr): |
|
|
44 |
""" |
|
|
45 |
Convert the array to RBG, then save the image |
|
|
46 |
:param path: |
|
|
47 |
:param lab_arr: |
|
|
48 |
:return: |
|
|
49 |
""" |
|
|
50 |
rgb_arr = color.lab2rgb(lab_arr) |
|
|
51 |
io.imsave(path, rgb_arr) |
|
|
52 |
|
|
|
53 |
def make_cluster(self, h, w): |
|
|
54 |
return Cluster(h, w, |
|
|
55 |
self.data[h][w][0], |
|
|
56 |
self.data[h][w][1], |
|
|
57 |
self.data[h][w][2]) |
|
|
58 |
|
|
|
59 |
def __init__(self, img, K, M): |
|
|
60 |
self.K = K |
|
|
61 |
self.M = M |
|
|
62 |
|
|
|
63 |
self.data = img |
|
|
64 |
self.image_height = self.data.shape[0] |
|
|
65 |
self.image_width = self.data.shape[1] |
|
|
66 |
self.N = self.image_height * self.image_width |
|
|
67 |
self.S = int(math.sqrt(self.N / self.K)) |
|
|
68 |
|
|
|
69 |
self.clusters = [] |
|
|
70 |
self.label = {} |
|
|
71 |
self.dis = np.full((self.image_height, self.image_width), np.inf) |
|
|
72 |
|
|
|
73 |
def init_clusters(self): |
|
|
74 |
h = self.S // 2 |
|
|
75 |
w = self.S // 2 |
|
|
76 |
while h < self.image_height: |
|
|
77 |
while w < self.image_width: |
|
|
78 |
self.clusters.append(self.make_cluster(h, w)) |
|
|
79 |
w += self.S |
|
|
80 |
w = self.S // 2 |
|
|
81 |
h += self.S |
|
|
82 |
|
|
|
83 |
def get_gradient(self, h, w): |
|
|
84 |
if w + 1 >= self.image_width: |
|
|
85 |
w = self.image_width - 2 |
|
|
86 |
if h + 1 >= self.image_height: |
|
|
87 |
h = self.image_height - 2 |
|
|
88 |
|
|
|
89 |
gradient = self.data[w + 1][h + 1][0] - self.data[w][h][0] + \ |
|
|
90 |
self.data[w + 1][h + 1][1] - self.data[w][h][1] + \ |
|
|
91 |
self.data[w + 1][h + 1][2] - self.data[w][h][2] |
|
|
92 |
return gradient |
|
|
93 |
|
|
|
94 |
def move_clusters(self): |
|
|
95 |
for cluster in self.clusters: |
|
|
96 |
cluster_gradient = self.get_gradient(cluster.h, cluster.w) |
|
|
97 |
for dh in range(-1, 2): |
|
|
98 |
for dw in range(-1, 2): |
|
|
99 |
_h = cluster.h + dh |
|
|
100 |
_w = cluster.w + dw |
|
|
101 |
new_gradient = self.get_gradient(_h, _w) |
|
|
102 |
if new_gradient < cluster_gradient: |
|
|
103 |
cluster.update(_h, _w, self.data[_h][_w][0], self.data[_h][_w][1], self.data[_h][_w][2]) |
|
|
104 |
cluster_gradient = new_gradient |
|
|
105 |
|
|
|
106 |
def assignment(self): |
|
|
107 |
for cluster in self.clusters: |
|
|
108 |
for h in range(cluster.h - 2 * self.S, cluster.h + 2 * self.S): |
|
|
109 |
if h < 0 or h >= self.image_height: continue |
|
|
110 |
for w in range(cluster.w - 2 * self.S, cluster.w + 2 * self.S): |
|
|
111 |
if w < 0 or w >= self.image_width: continue |
|
|
112 |
L, A, B = self.data[h][w] |
|
|
113 |
Dc = math.sqrt( |
|
|
114 |
math.pow(L - cluster.l, 2) + |
|
|
115 |
math.pow(A - cluster.a, 2) + |
|
|
116 |
math.pow(B - cluster.b, 2)) |
|
|
117 |
Ds = math.sqrt( |
|
|
118 |
math.pow(h - cluster.h, 2) + |
|
|
119 |
math.pow(w - cluster.w, 2)) |
|
|
120 |
D = math.sqrt(math.pow(Dc // self.M, 2) + math.pow(Ds // self.S, 2)) |
|
|
121 |
if D < self.dis[h][w]: |
|
|
122 |
if (h, w) not in self.label: |
|
|
123 |
self.label[(h, w)] = cluster |
|
|
124 |
cluster.pixels.append((h, w)) |
|
|
125 |
else: |
|
|
126 |
self.label[(h, w)].pixels.remove((h, w)) |
|
|
127 |
self.label[(h, w)] = cluster |
|
|
128 |
cluster.pixels.append((h, w)) |
|
|
129 |
self.dis[h][w] = D |
|
|
130 |
|
|
|
131 |
def update_cluster(self): |
|
|
132 |
for cluster in self.clusters: |
|
|
133 |
sum_h = sum_w = number = 0 |
|
|
134 |
for p in cluster.pixels: |
|
|
135 |
sum_h += p[0] |
|
|
136 |
sum_w += p[1] |
|
|
137 |
number += 1 |
|
|
138 |
_h = sum_h // number |
|
|
139 |
_w = sum_w // number |
|
|
140 |
cluster.update(_h, _w, self.data[_h][_w][0], self.data[_h][_w][1], self.data[_h][_w][2]) |
|
|
141 |
|
|
|
142 |
def save_current_image(self, name): |
|
|
143 |
image_arr = np.copy(self.data) |
|
|
144 |
for cluster in self.clusters: |
|
|
145 |
for p in cluster.pixels: |
|
|
146 |
image_arr[p[0]][p[1]][0] = cluster.l |
|
|
147 |
image_arr[p[0]][p[1]][1] = cluster.a |
|
|
148 |
image_arr[p[0]][p[1]][2] = cluster.b |
|
|
149 |
''' |
|
|
150 |
image_arr[cluster.h][cluster.w][0] = 0 |
|
|
151 |
image_arr[cluster.h][cluster.w][1] = 0 |
|
|
152 |
image_arr[cluster.h][cluster.w][2] = 0 |
|
|
153 |
''' |
|
|
154 |
|
|
|
155 |
# self.save_lab_image(name, image_arr) |
|
|
156 |
return image_arr |
|
|
157 |
|
|
|
158 |
|
|
|
159 |
def iterate_10times(self, filename = 'temp'): |
|
|
160 |
self.init_clusters() |
|
|
161 |
self.move_clusters() |
|
|
162 |
temp_img = [] |
|
|
163 |
for i in range(5): |
|
|
164 |
self.assignment() |
|
|
165 |
self.update_cluster() |
|
|
166 |
name = filename[:len(filename) - 4] + 'M{m}_K{k}_loop{loop}.png'.format(loop=i, m=self.M, k=self.K) |
|
|
167 |
temp_img.append(self.save_current_image(name)) |
|
|
168 |
# print(temp_img[-1].shape) |
|
|
169 |
return temp_img[-1] |
|
|
170 |
|
|
|
171 |
|
|
|
172 |
|
|
|
173 |
# if __name__ == '__main__': |
|
|
174 |
# filename = 'LIDC-IDRI-0087_3000640_1.jpeg' |
|
|
175 |
# p = SLICProcessor(filename, 2048, 3) |
|
|
176 |
# p.iterate_10times(filename) |
|
|
177 |
# p = SLICProcessor(filename, 80, 10) |
|
|
178 |
# p.iterate_10times(filename) |
|
|
179 |
# p = SLICProcessor(filename, 500, 40) |
|
|
180 |
# p.iterate_10times() |
|
|
181 |
# p = SLICProcessor(filename, 1024, 3) |
|
|
182 |
# p.iterate_10times(filename) |
|
|
183 |
# p = SLICProcessor(filename, 200, 5) |
|
|
184 |
# p.iterate_10times() |
|
|
185 |
# p = SLICProcessor(filename, 300, 5) |
|
|
186 |
# p.iterate_10times() |
|
|
187 |
# p = SLICProcessor(filename, 500, 5) |
|
|
188 |
# p.iterate_10times() |
|
|
189 |
|
|
|
190 |
# import os |
|
|
191 |
|
|
|
192 |
# path = 'res/' |
|
|
193 |
# filelist = os.listdir(path) |
|
|
194 |
|
|
|
195 |
# for onefile in filelist: |
|
|
196 |
# filename = path + onefile |
|
|
197 |
# p = SLICProcessor(filename, 1000, 5) |
|
|
198 |
# p.iterate_10times(filename) |