Diff of /features/otsu.py [000000] .. [f77492]

Switch to unified view

a b/features/otsu.py
1
from __future__ import division
2
import math
3
import numpy as np
4
import cv2
5
# import and use one of 3 libraries PIL, cv2, or scipy in that order
6
USE_PIL = True
7
USE_CV2 = False
8
USE_SCIPY = False
9
try:
10
    import PIL
11
    from PIL import Image
12
    raise ImportError
13
except ImportError:
14
    USE_PIL = False
15
if not USE_PIL:
16
    USE_CV2 = True
17
    try:
18
        import cv2
19
    except ImportError:
20
        USE_CV2 = False
21
if not USE_PIL and not USE_CV2:
22
    USE_SCIPY = True
23
    try:
24
        import scipy
25
        from scipy import misc
26
    except ImportError:
27
        USE_SCIPY = False
28
        raise RuntimeError("couldn't load ANY image library")
29
30
class _OtsuPyramid(object):
31
    """segments histogram into pyramid of histograms, each histogram
32
    half the size of the previous. Also generate omega and mu values
33
    for each histogram in the pyramid.
34
    """
35
36
    def load_image(self, im, bins=256):
37
        """ bins is number of intensity levels """
38
        if not type(im) == np.ndarray:
39
            raise ValueError(
40
                'must be passed numpy array. Got ' + str(type(im)) +
41
                ' instead'
42
            )
43
        if im.ndim == 3:
44
            raise ValueError(
45
                'image must be greyscale (and single value per pixel)'
46
            )
47
        self.im = im
48
        hist, ranges = np.histogram(im, bins) #将输入转化为直方图
49
        # print("hist:",hist,"range:",ranges) # hist表示每个区间内的元素个数,ranges表示区间 
50
        # convert the numpy array to list of ints
51
        hist = [int(h) for h in hist]
52
        histPyr, omegaPyr, muPyr, ratioPyr = \
53
            self._create_histogram_and_stats_pyramids(hist)
54
        # arrange so that pyramid[0] is the smallest pyramid
55
        self.omegaPyramid = [omegas for omegas in reversed(omegaPyr)] # 前缀和
56
        #  print("self.omeaPtramid:",len(self.omegaPyramid[0]))
57
        self.muPyramid = [mus for mus in reversed(muPyr)]# 乘了像素的 前缀和
58
        self.ratioPyramid = ratioPyr# [1, 2, 2, 2, 2, 2, 2, 2]
59
        
60
    def _create_histogram_and_stats_pyramids(self, hist):
61
        """Expects hist to be a single list of numbers (no numpy array)
62
        takes an input histogram (with 256 bins) and iteratively
63
        compresses it by a factor of 2 until the last compressed
64
        histogram is of size 2. It stores all these generated histograms
65
        in a list-like pyramid structure. Finally, create corresponding
66
        omega and mu lists for each histogram and return the 3
67
        generated pyramids.
68
        """
69
        bins = len(hist)
70
        # eventually you can replace this with a list if you cannot evenly
71
        # compress a histogram
72
        ratio = 2
73
        reductions = int(math.log(bins, ratio)) #ln(bins)/ln(ratio),约等于开方
74
        compressionFactor = []
75
        histPyramid = []
76
        omegaPyramid = []
77
        muPyramid = []
78
        for _ in range(reductions):
79
            histPyramid.append(hist)
80
            reducedHist = [sum(hist[i:i+ratio]) for i in range(0, bins, ratio)]
81
            # collapse a list to half its size, combining the two collpased
82
            # numbers into one
83
            hist = reducedHist
84
            # update bins to reflect the img_nums of the new histogram
85
            bins = bins // ratio
86
            compressionFactor.append(ratio)
87
        # first "compression" was 1, aka it's the original histogram
88
        compressionFactor[0] = 1
89
        # print("the length of histPyramid:",len(histPyramid))
90
        for hist in histPyramid:
91
            omegas, mus, muT = \
92
                self._calculate_omegas_and_mus_from_histogram(hist)
93
            omegaPyramid.append(omegas)
94
            muPyramid.append(mus)
95
        return histPyramid, omegaPyramid, muPyramid, compressionFactor
96
97
    def _calculate_omegas_and_mus_from_histogram(self, hist):
98
        """ Comput histogram statistical data: omega and mu for each
99
        intensity level in the histogram
100
        """
101
        probabilityLevels, meanLevels = \
102
            self._calculate_histogram_pixel_stats(hist)
103
        bins = len(probabilityLevels)
104
        # these numbers are critical towards calculations, so we make sure
105
        # they are float
106
        ptotal = float(0)
107
        # sum of probability levels up to k
108
        omegas = []
109
        for i in range(bins):
110
            ptotal += probabilityLevels[i]
111
            omegas.append(ptotal)
112
        mtotal = float(0)
113
        mus = []
114
        for i in range(bins):
115
            mtotal += meanLevels[i]
116
            mus.append(mtotal)
117
        # muT is the total mean levels.
118
        muT = float(mtotal)
119
        return omegas, mus, muT
120
121
    def _calculate_histogram_pixel_stats(self, hist):
122
        """Given a histogram, compute pixel probability and mean
123
        levels for each bin in the histogram. Pixel probability
124
        represents the likely-hood that a pixel's intensty resides in
125
        a specific bin. Pixel mean is the intensity-weighted pixel
126
        probability.
127
        """
128
        # bins = number of intensity levels
129
        bins = len(hist)
130
        # print("bins:",bins)
131
        # N = number of pixels in image. Make it float so that division by
132
        # N will be a float
133
        N = float(sum(hist))
134
        # percentage of pixels at each intensity level: i => P_i
135
        hist_probability = [hist[i] / N for i in range(bins)]
136
        # mean level of pixels at intensity level i   => i * P_i
137
        pixel_mean = [i * hist_probability[i] for i in range(bins)]
138
        # print("N:",N)
139
        return hist_probability, pixel_mean
140
141
142
class OtsuFastMultithreshold(_OtsuPyramid):
143
    """Sacrifices precision for speed. OtsuFastMultithreshold can dial
144
    in to the threshold but still has the possibility that its
145
    thresholds will not be the same as a naive-Otsu's method would give
146
    """
147
148
    def calculate_k_thresholds(self, k):
149
        self.threshPyramid = []
150
        start = self._get_smallest_fitting_pyramid(k)
151
        self.bins = len(self.omegaPyramid[start])
152
        # print("self.bins:",self.bins)
153
        thresholds = self._get_first_guess_thresholds(k)
154
        # print("thresholds:",thresholds)
155
        # give hunting algorithm full range so that initial thresholds
156
        # can become any value (0-bins)
157
        deviate = self.bins // 2
158
        for i in range(start, len(self.omegaPyramid)):
159
            omegas = self.omegaPyramid[i] # 个数平均前缀和
160
            mus = self.muPyramid[i] # 平均像素前缀和
161
            hunter = _ThresholdHunter(omegas, mus, deviate)
162
            thresholds = \
163
                hunter.find_best_thresholds_around_estimates(thresholds)
164
            self.threshPyramid.append(thresholds)
165
            # how much our "just analyzed" pyramid was compressed from the
166
            # previous one
167
            scaling = self.ratioPyramid[i]
168
            # deviate should be equal to the compression factor of the
169
            # previous histogram.
170
            deviate = scaling
171
            thresholds = [t * scaling for t in thresholds]
172
        # return readjusted threshold (since it was scaled up incorrectly in
173
        # last loop)
174
        # print("true thresholds:",thresholds)
175
        return [t // scaling for t in thresholds]
176
177
    def _get_smallest_fitting_pyramid(self, k):
178
        """Return the index for the smallest pyramid set that can fit
179
        K thresholds
180
        """
181
        for i, pyramid in enumerate(self.omegaPyramid):
182
            if len(pyramid) >= k:
183
                return i
184
185
    def _get_first_guess_thresholds(self, k):
186
        """Construct first-guess thresholds based on number of
187
        thresholds (k) and constraining intensity values. FirstGuesses
188
        will be centered around middle intensity value.
189
        """
190
        kHalf = k // 2
191
        midway = self.bins // 2
192
        firstGuesses = [midway - i for i in range(kHalf, 0, -1)] + [midway] + \
193
            [midway + i for i in range(1, kHalf)]
194
        # print("firstGuesses:",firstGuesses)
195
        # additional threshold in case k is odd
196
        firstGuesses.append(self.bins - 1)
197
        return firstGuesses[:k]
198
199
    def apply_thresholds_to_image(self, thresholds, im=None):
200
        if im is None:
201
            im = self.im
202
        k = len(thresholds)
203
        bookendedThresholds = [None] + thresholds + [None]
204
        # I think you need to use 255 / k *...
205
        greyValues = [0] + [int(256 / k * (i + 1)) for i in range(0, k - 1)] \
206
            + [255]
207
        greyValues = np.array(greyValues, dtype=np.uint8)
208
        finalImage = np.zeros(im.shape, dtype=np.uint8)
209
        for i in range(k + 1):
210
            kSmall = bookendedThresholds[i]
211
            # True portions of bw represents pixels between the two thresholds
212
            bw = np.ones(im.shape, dtype=np.bool8)
213
            if kSmall:
214
                bw = (im >= kSmall)
215
            kLarge = bookendedThresholds[i + 1]
216
            if kLarge:
217
                bw &= (im < kLarge)
218
            greyLevel = greyValues[i]
219
            # apply grey-color to black-and-white image
220
            greyImage = bw * greyLevel
221
            # add grey portion to image. There should be no overlap between
222
            # each greyImage added
223
            finalImage += greyImage
224
        return finalImage
225
226
227
class _ThresholdHunter(object):
228
    """Hunt/deviate around given thresholds in a small region to look
229
    for a better threshold
230
    """
231
232
    def __init__(self, omegas, mus, deviate=2):
233
        self.sigmaB = _BetweenClassVariance(omegas, mus)
234
        # used to be called L
235
        self.bins = self.sigmaB.bins
236
        # hunt 2 (or other amount) to either side of thresholds
237
        self.deviate = deviate
238
239
    def find_best_thresholds_around_estimates(self, estimatedThresholds):
240
        """Given guesses for best threshold, explore to either side of
241
        the threshold and return the best result.
242
        """
243
        bestResults = (
244
            0, estimatedThresholds, [0 for t in estimatedThresholds]
245
        )
246
        # print("bestResults:",bestResults)
247
        bestThresholds = estimatedThresholds
248
        bestVariance = 0
249
        for thresholds in self._jitter_thresholds_generator(
250
                estimatedThresholds, 0, self.bins):
251
            # print("thresholds:",thresholds)
252
            variance = self.sigmaB.get_total_variance(thresholds)
253
            if variance == bestVariance:
254
                if sum(thresholds) < sum(bestThresholds):
255
                    # keep lowest average set of thresholds
256
                    bestThresholds = thresholds
257
            elif variance > bestVariance:
258
                bestVariance = variance
259
                bestThresholds = thresholds
260
        return bestThresholds
261
262
    def find_best_thresholds_around_estimates_experimental(self, estimatedThresholds):
263
        """Experimental threshold hunting uses scipy optimize method.
264
        Finds ok thresholds but doesn't work quite as well
265
        """
266
        estimatedThresholds = [int(k) for k in estimatedThresholds]
267
        if sum(estimatedThresholds) < 10:
268
            return self.find_best_thresholds_around_estimates(
269
                estimatedThresholds
270
            )
271
        # print('estimated', estimatedThresholds)
272
        fxn_to_minimize = lambda x: -1 * self.sigmaB.get_total_variance(
273
            [int(k) for k in x]
274
        )
275
        bestThresholds = scipy.optimize.fmin(
276
            fxn_to_minimize, estimatedThresholds
277
        )
278
        bestThresholds = [int(k) for k in bestThresholds]
279
        # print('bestTresholds', bestThresholds)
280
        return bestThresholds
281
282
    def _jitter_thresholds_generator(self, thresholds, min_, max_):
283
        pastThresh = thresholds[0]
284
        # print("pastThresd:",pastThresh)
285
        if len(thresholds) == 1:
286
            # -2 through +2
287
            for offset in range(-self.deviate, self.deviate + 1):
288
                thresh = pastThresh + offset
289
                if thresh < min_ or thresh >= max_:
290
                    # skip since we are conflicting with bounds
291
                    continue
292
                yield [thresh]
293
        else:
294
            # new threshold without our threshold included
295
            thresholds = thresholds[1:]
296
            # number of threshold left to generate in chain
297
            m = len(thresholds)
298
            for offset in range(-self.deviate, self.deviate + 1):
299
                thresh = pastThresh + offset
300
                # verify we don't use the same value as the previous threshold
301
                # and also verify our current threshold will not push the last
302
                # threshold past max
303
                if thresh < min_ or thresh + m >= max_:
304
                    continue
305
                recursiveGenerator = self._jitter_thresholds_generator(
306
                    thresholds, thresh + 1, max_
307
                )
308
                for otherThresholds in recursiveGenerator:
309
                    yield [thresh] + otherThresholds
310
311
312
class _BetweenClassVariance(object):
313
314
    def __init__(self, omegas, mus):
315
        self.omegas = omegas
316
        self.mus = mus
317
        # number of bins / luminosity choices
318
        self.bins = len(mus)
319
        self.muTotal = sum(mus)
320
321
    def get_total_variance(self, thresholds):
322
        """Function will pad the thresholds argument with minimum and
323
        maximum thresholds to calculate between class variance
324
        """
325
        thresholds = [0] + thresholds + [self.bins - 1]
326
        numClasses = len(thresholds) - 1
327
        # print("thesholds:",thresholds)
328
        sigma = 0
329
        for i in range(numClasses):
330
            k1 = thresholds[i]
331
            k2 = thresholds[i+1]
332
            sigma += self._between_thresholds_variance(k1, k2)
333
        # print("sigma:",sigma)
334
        return sigma
335
336
    def _between_thresholds_variance(self, k1, k2):
337
        """to be usedin calculating between-class variances only!"""
338
        # print("len(self.omegas)):",len(self.omegas))
339
        # print("k2:",k2, "kl:",k1)
340
        omega = self.omegas[k2] - self.omegas[k1]
341
        mu = self.mus[k2] - self.mus[k1]
342
        muT = self.muTotal
343
        return omega * ((mu - muT)**2)
344
import math 
345
def normalize(img):
346
    # 需要特别注意的是无穷大和无穷小 
347
    Max = -1
348
    Min = 10000
349
    for i in range(img.shape[0]):
350
        for j in range(img.shape[1]):
351
            if img[i][j] > Max :
352
                Max = img[i][j]
353
            if img[i][j] < Min:
354
                Min = img[i][j]
355
    print(Max, Min)
356
    for i in range(img.shape[0]):
357
        for j in range(img.shape[0]):
358
            if math.isnan(img[i][j]):
359
                img[i][j] = 255.
360
    img = (img - Min) / (Max - Min)
361
    return img * 255.
362
363
def _otsu(img, categories_pixel_nums = 1):
364
    '''
365
    ret_num: the number of image return
366
    img_nums: the number of calculate thresholds
367
    categories_pixel_nums: the number of pixel categories
368
    return the images with only categories_pixel_nums types of pixels
369
    '''
370
    return normalize(img.copy())
371
    # img = normalize(img.copy())
372
    # ot = OtsuFastMultithreshold()
373
    # ot.load_image(img)
374
    # kThresholds = ot.calculate_k_thresholds(categories_pixel_nums)
375
    # #  print(total)
376
    # return ot.apply_thresholds_to_image(kThresholds,img)
377
    # plt.imsave('C:/Users/RL/Desktop/可解释性的特征学习分类/otsu/' + str(i) + '.jpg',crushed,cmap="gray")
378
    # ans.append(crushed)
379
    # plt.figure()
380
    # plt.subplot(121)
381
    # plt.imshow(img[i],cmap="gray")
382
    # plt.subplot(122)
383
    # plt.imshow(crushed,cmap="gray")
384
    # plt.show()
385
386
#otsu的接口,返回处理后的图片
387
def otsu_helper(img, upper=0.5, down = -0.5,categories=1):
388
    ot = OtsuFastMultithreshold()
389
    ot.load_image(img)
390
    return ot.apply_thresholds_to_image([down, upper], img)