Diff of /features/slic.py [000000] .. [f77492]

Switch to side-by-side view

--- a
+++ b/features/slic.py
@@ -0,0 +1,198 @@
+import math
+from skimage import io, color
+import numpy as np
+from tqdm import trange
+
+
+class Cluster(object):
+    cluster_index = 1
+
+    def __init__(self, h, w, l=0, a=0, b=0):
+        self.update(h, w, l, a, b)
+        self.pixels = []
+        self.no = self.cluster_index
+        Cluster.cluster_index += 1
+
+    def update(self, h, w, l, a, b):
+        self.h = h
+        self.w = w
+        self.l = l
+        self.a = a
+        self.b = b
+
+    def __str__(self):
+        return "{},{}:{} {} {} ".format(self.h, self.w, self.l, self.a, self.b)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+class SLICProcessor(object):
+    @staticmethod
+    def open_image(path):
+        """
+        Return:
+            3D array, row col [LAB]
+        """
+        rgb = io.imread(path)
+        lab_arr = color.rgb2lab(rgb)
+        print(lab_arr.shape)
+        return lab_arr
+
+    @staticmethod
+    def save_lab_image(path, lab_arr):
+        """
+        Convert the array to RBG, then save the image
+        :param path:
+        :param lab_arr:
+        :return:
+        """
+        rgb_arr = color.lab2rgb(lab_arr)
+        io.imsave(path, rgb_arr)
+
+    def make_cluster(self, h, w):
+        return Cluster(h, w,
+                       self.data[h][w][0],
+                       self.data[h][w][1],
+                       self.data[h][w][2])
+
+    def __init__(self, img, K, M):
+        self.K = K
+        self.M = M
+
+        self.data = img
+        self.image_height = self.data.shape[0]
+        self.image_width = self.data.shape[1]
+        self.N = self.image_height * self.image_width
+        self.S = int(math.sqrt(self.N / self.K))
+
+        self.clusters = []
+        self.label = {}
+        self.dis = np.full((self.image_height, self.image_width), np.inf)
+
+    def init_clusters(self):
+        h = self.S // 2
+        w = self.S // 2
+        while h < self.image_height:
+            while w < self.image_width:
+                self.clusters.append(self.make_cluster(h, w))
+                w += self.S
+            w = self.S // 2
+            h += self.S
+
+    def get_gradient(self, h, w):
+        if w + 1 >= self.image_width:
+            w = self.image_width - 2
+        if h + 1 >= self.image_height:
+            h = self.image_height - 2
+
+        gradient = self.data[w + 1][h + 1][0] - self.data[w][h][0] + \
+                   self.data[w + 1][h + 1][1] - self.data[w][h][1] + \
+                   self.data[w + 1][h + 1][2] - self.data[w][h][2]
+        return gradient
+
+    def move_clusters(self):
+        for cluster in self.clusters:
+            cluster_gradient = self.get_gradient(cluster.h, cluster.w)
+            for dh in range(-1, 2):
+                for dw in range(-1, 2):
+                    _h = cluster.h + dh
+                    _w = cluster.w + dw
+                    new_gradient = self.get_gradient(_h, _w)
+                    if new_gradient < cluster_gradient:
+                        cluster.update(_h, _w, self.data[_h][_w][0], self.data[_h][_w][1], self.data[_h][_w][2])
+                        cluster_gradient = new_gradient
+
+    def assignment(self):
+        for cluster in self.clusters:
+            for h in range(cluster.h - 2 * self.S, cluster.h + 2 * self.S):
+                if h < 0 or h >= self.image_height: continue
+                for w in range(cluster.w - 2 * self.S, cluster.w + 2 * self.S):
+                    if w < 0 or w >= self.image_width: continue
+                    L, A, B = self.data[h][w]
+                    Dc = math.sqrt(
+                        math.pow(L - cluster.l, 2) +
+                        math.pow(A - cluster.a, 2) +
+                        math.pow(B - cluster.b, 2))
+                    Ds = math.sqrt(
+                        math.pow(h - cluster.h, 2) +
+                        math.pow(w - cluster.w, 2))
+                    D = math.sqrt(math.pow(Dc // self.M, 2) + math.pow(Ds // self.S, 2))
+                    if D < self.dis[h][w]:
+                        if (h, w) not in self.label:
+                            self.label[(h, w)] = cluster
+                            cluster.pixels.append((h, w))
+                        else:
+                            self.label[(h, w)].pixels.remove((h, w))
+                            self.label[(h, w)] = cluster
+                            cluster.pixels.append((h, w))
+                        self.dis[h][w] = D
+
+    def update_cluster(self):
+        for cluster in self.clusters:
+            sum_h = sum_w = number = 0
+            for p in cluster.pixels:
+                sum_h += p[0]
+                sum_w += p[1]
+                number += 1
+                _h = sum_h // number
+                _w = sum_w // number
+                cluster.update(_h, _w, self.data[_h][_w][0], self.data[_h][_w][1], self.data[_h][_w][2])
+
+    def save_current_image(self, name):
+        image_arr = np.copy(self.data)
+        for cluster in self.clusters:
+            for p in cluster.pixels:
+                image_arr[p[0]][p[1]][0] = cluster.l
+                image_arr[p[0]][p[1]][1] = cluster.a
+                image_arr[p[0]][p[1]][2] = cluster.b
+            '''
+            image_arr[cluster.h][cluster.w][0] = 0
+            image_arr[cluster.h][cluster.w][1] = 0
+            image_arr[cluster.h][cluster.w][2] = 0
+            '''
+
+        # self.save_lab_image(name, image_arr)
+        return image_arr
+
+
+    def iterate_10times(self, filename = 'temp'):
+        self.init_clusters()
+        self.move_clusters()
+        temp_img = []
+        for i in range(5):
+            self.assignment()
+            self.update_cluster()
+            name = filename[:len(filename) - 4] + 'M{m}_K{k}_loop{loop}.png'.format(loop=i, m=self.M, k=self.K)
+            temp_img.append(self.save_current_image(name))
+        # print(temp_img[-1].shape)
+        return temp_img[-1]
+
+
+
+# if __name__ == '__main__':
+#     filename = 'LIDC-IDRI-0087_3000640_1.jpeg'
+    # p = SLICProcessor(filename, 2048, 3)
+    # p.iterate_10times(filename)
+    # p = SLICProcessor(filename, 80, 10)
+    # p.iterate_10times(filename)
+    # p = SLICProcessor(filename, 500, 40)
+    # p.iterate_10times()
+    # p = SLICProcessor(filename, 1024, 3)
+    # p.iterate_10times(filename)
+    # p = SLICProcessor(filename, 200, 5)
+    # p.iterate_10times()
+    # p = SLICProcessor(filename, 300, 5)
+    # p.iterate_10times()
+    # p = SLICProcessor(filename, 500, 5)
+    # p.iterate_10times()
+
+    # import os
+    
+    # path = 'res/'
+    # filelist = os.listdir(path)
+
+    # for onefile in filelist:
+    #     filename = path + onefile
+    #     p = SLICProcessor(filename, 1000, 5)
+    #     p.iterate_10times(filename)