Diff of /slideflow/norm/utils.py [000000] .. [78ef36]

Switch to unified view

a b/slideflow/norm/utils.py
1
"""
2
From https://github.com/wanghao14/Stain_Normalization
3
Uses the spams package:
4
5
http://spams-devel.gforge.inria.fr/index.html
6
7
Use with Python via e.g. https://anaconda.org/conda-forge/python-spams
8
"""
9
10
from __future__ import division
11
12
import cv2
13
import numpy as np
14
from typing import Union, List, Tuple
15
16
# -----------------------------------------------------------------------------
17
18
# Stain normalizer default fits.
19
# v1 is the fit with target sf.norm.norm_tile.jpg (default in version <1.6)
20
# v2 is a hand-tuned fit
21
# v3 is fit using an average of ~50k tiles across ~450 slides from TCGA (default for versions >=1.6)
22
23
# Preset fit parameters, indexed as fit_presets[method][version][parameter].
fit_presets = {
    # Reinhard variants: three target means and three target standard
    # deviations per preset.
    'reinhard': {
        'v1': {
            'target_means': np.array([72.272896, 22.99831, -13.860236]),
            'target_stds': np.array([15.594496, 9.642087, 9.290526]),
        },
        'v2': {
            'target_means': np.array([72.909996, 20.8268, -4.9465137]),
            'target_stds': np.array([18.560713, 14.889295, 5.6756697]),
        },
        'v3': {
            'target_means': np.array([65.22132, 28.934267, -14.142519]),
            'target_stds': np.array([15.800227, 9.263783, 6.0213304]),
        },
    },
    'reinhard_fast': {
        'v1': {
            'target_means': np.array([63.71194, 20.716246, -12.290746]),
            'target_stds': np.array([14.52781, 8.344005, 8.300264]),
        },
        'v2': {
            'target_means': np.array([69.20197, 19.82498, -4.690998]),
            'target_stds': np.array([17.71583, 14.156416, 5.4176064]),
        },
        'v3': {
            'target_means': np.array([58.12343, 26.483482, -12.701005]),
            'target_stds': np.array([14.675022, 7.5744166, 5.226378]),
        },
    },
    # Macenko variants: a 3x2 target stain matrix plus two target
    # concentrations per preset.
    'macenko': {
        'v1': {
            'stain_matrix_target': np.array([
                [0.63111544, 0.24816133],
                [0.6962834, 0.8226449],
                [0.34188122, 0.5115382],
            ]),
            'target_concentrations': np.array([1.4423684, 0.9685806]),
        },
        'v2': {
            'stain_matrix_target': np.array([
                [0.5626, 0.2159],
                [0.7201, 0.8012],
                [0.4062, 0.5581],
            ]),
            'target_concentrations': np.array([1.9705, 1.0308]),
        },
        'v3': {
            'stain_matrix_target': np.array([
                [0.5062568, 0.2218694],
                [0.75322306, 0.8652155],
                [0.40691733, 0.42241502],
            ]),
            'target_concentrations': np.array([1.7656903, 1.2797493]),
        },
    },
    'macenko_fast': {
        'v1': {
            'stain_matrix_target': np.array([
                [0.6148019, 0.21480364],
                [0.7010872, 0.82317936],
                [0.36124164, 0.5255809],
            ]),
            'target_concentrations': np.array([1.8029537, 0.9606744]),
        },
        'v2': {
            'stain_matrix_target': np.array([
                [0.5626, 0.2159],
                [0.7201, 0.8012],
                [0.4062, 0.5581],
            ]),
            'target_concentrations': np.array([1.9705, 1.0308]),
        },
        'v3': {
            'stain_matrix_target': np.array([
                [0.52000326, 0.2623537],
                [0.73508584, 0.83495414],
                [0.4249617, 0.4630997],
            ]),
            'target_concentrations': np.array([2.0259454, 1.4088874]),
        },
    },
    # Vahadane variants: a 2x3 target stain matrix per preset.
    'vahadane_sklearn': {
        'v1': {
            'stain_matrix_target': np.array([
                [0.9840825, 0.17771211, 0.0],
                [0.0, 0.87096226, 0.49134994],
            ]),
        },
        'v2': {
            'stain_matrix_target': np.array([
                [0.95465684, 0.29770842, 0.0],
                [0.0, 0.8053334, 0.59282213],
            ]),
        },
    },
    'vahadane_spams': {
        'v1': {
            'stain_matrix_target': np.array([
                [0.54176575, 0.75441414, 0.37060648],
                [0.17089975, 0.8640189, 0.4735658],
            ]),
        },
        'v2': {
            'stain_matrix_target': np.array([
                [0.4435433, 0.7502863, 0.4902447],
                [0.27688965, 0.8088818, 0.5186929],
            ]),
        },
    },
}
81
82
# Stain normalizer default augmentation spaces.
83
# v1 is derived from the standard deviation of fit values for ~50k tiles from ~450 slides in TCGA.
84
85
# Preset augmentation spreads, indexed as
# augment_presets[method][version][parameter]. Within each method, 'v1' and
# 'v2' share the same base values and differ only in the multiplier (5 vs. 3).
augment_presets = {
    'reinhard': {
        'v1': {
            'means_stdev': np.array([1.1882676, 1.3114343, 1.1200949]) * 5,
            'stds_stdev': np.array([0.5123385, 0.37919158, 0.26019168]) * 5,
        },
        'v2': {
            'means_stdev': np.array([1.1882676, 1.3114343, 1.1200949]) * 3,
            'stds_stdev': np.array([0.5123385, 0.37919158, 0.26019168]) * 3,
        },
    },
    'reinhard_fast': {
        'v1': {
            'means_stdev': np.array([1.2963034, 1.0061347, 0.90867484]) * 5,
            'stds_stdev': np.array([0.47548684, 0.3956356, 0.23499836]) * 5,
        },
        'v2': {
            'means_stdev': np.array([1.2963034, 1.0061347, 0.90867484]) * 3,
            'stds_stdev': np.array([0.47548684, 0.3956356, 0.23499836]) * 3,
        },
    },
    'macenko': {
        'v1': {
            'matrix_stdev': np.array([
                [0.00893346, 0.01153686],
                [0.00659814, 0.00722771],
                [0.00726339, 0.01352414],
            ]) * 5,
            'concentrations_stdev': np.array([0.06665898, 0.06770515]) * 5,
        },
        'v2': {
            'matrix_stdev': np.array([
                [0.00893346, 0.01153686],
                [0.00659814, 0.00722771],
                [0.00726339, 0.01352414],
            ]) * 3,
            'concentrations_stdev': np.array([0.06665898, 0.06770515]) * 3,
        },
    },
    'macenko_fast': {
        'v1': {
            'matrix_stdev': np.array([
                [0.00794701, 0.01137106],
                [0.00559027, 0.00642623],
                [0.00609103, 0.01144302],
            ]) * 5,
            'concentrations_stdev': np.array([0.06623945, 0.08137263]) * 5,
        },
        'v2': {
            'matrix_stdev': np.array([
                [0.00794701, 0.01137106],
                [0.00559027, 0.00642623],
                [0.00609103, 0.01144302],
            ]) * 3,
            'concentrations_stdev': np.array([0.06623945, 0.08137263]) * 3,
        },
    },
}
119
120
# -----------------------------------------------------------------------------
121
122
# Three-component coordinates per illuminant name, indexed as
# illuminants[name][observer] with observer given as the string "2" or "10".
illuminants = {
    "A": {
        "2": (1.098466069456375, 1, 0.3558228003436005),
        "10": (1.111420406956693, 1, 0.3519978321919493),
    },
    "D50": {
        "2": (0.9642119944211994, 1, 0.8251882845188288),
        "10": (0.9672062750333777, 1, 0.8142801513128616),
    },
    "D55": {
        "2": (0.956797052643698, 1, 0.9214805860173273),
        "10": (0.9579665682254781, 1, 0.9092525159847462),
    },
    "D65": {
        "2": (0.95047, 1.0, 1.08883),
        "10": (0.94809667673716, 1, 1.0730513595166162),
    },
    "D75": {
        "2": (0.9497220898840717, 1, 1.226393520724154),
        "10": (0.9441713925645873, 1, 1.2064272211720228),
    },
    "E": {
        "2": (1.0, 1.0, 1.0),
        "10": (1.0, 1.0, 1.0),
    },
}
145
146
# RGB -> XYZ 3x3 matrix, pre-built once per supported float precision and
# keyed by the dtype name ('float16', 'float32', 'float64').
rgb_to_xyz_kernels = {
    precision: np.array(
        [[0.412453, 0.357580, 0.180423],
         [0.212671, 0.715160, 0.072169],
         [0.019334, 0.119193, 0.950227]],
        dtype=precision,
    )
    for precision in ('float16', 'float32', 'float64')
}
156
157
# inv of:
158
# [[0.412453, 0.35758 , 0.180423],
159
#  [0.212671, 0.71516 , 0.072169],
160
#  [0.019334, 0.119193, 0.950227]]
161
# XYZ -> RGB 3x3 matrix, pre-built once per supported float precision and
# keyed by the dtype name ('float16', 'float32', 'float64').
xyz_to_rgb_kernels = {
    precision: np.array(
        [[3.24048134, -1.53715152, -0.49853633],
         [-0.96925495, 1.87599, 0.04155593],
         [0.05564664, -0.20404134, 1.05731107]],
        dtype=precision,
    )
    for precision in ('float16', 'float32', 'float64')
}
171
172
######################################
173
174
175
def brightness_percentile(I):
    """Return the 90th percentile of all values in ``I``."""
    return np.percentile(I, 90)


def standardize_brightness(I, mask=False):
    """Rescale an image so that its 90th-percentile value maps to 255.

    :param I: image array with values on a 0-255 scale.
    :param mask: if True, pixels whose last-axis values are all 255 are left
        out of the percentile estimate and are forced back to 255 in the
        result.
    :return: a new uint8 array, rescaled and clipped to [0, 255]. When no
        usable brightness reference exists (no unmasked pixels, or a
        percentile of 0) the image is returned unscaled (clipped, as uint8).
    """
    if mask:
        ones = np.all(I == 255, axis=len(I.shape) - 1)
        bI = I[~ones]
    else:
        ones = None
        bI = I

    # With every pixel masked out there is nothing to take a percentile of.
    if bI.size == 0:
        return np.clip(I, 0, 255).astype(np.uint8)

    p = brightness_percentile(bI)
    # A zero percentile (e.g. a mostly black image) would divide by zero and
    # push NaN/inf through the uint8 cast, so skip the rescale instead.
    if p <= 0:
        return np.clip(I, 0, 255).astype(np.uint8)

    clipped = np.clip(I * 255.0 / p, 0, 255).astype(np.uint8)
    if mask:
        clipped[ones] = 255
    return clipped
193
194
195
def remove_zeros(I):
    """
    Replace every zero entry of ``I`` with 1.

    The array is modified in place; the same object is returned.
    :param I: uint8 array
    :return: ``I``, with zeros overwritten by ones
    """
    I[I == 0] = 1
    return I
204
205
206
def RGB_to_OD(I):
    """
    Convert from RGB to optical density, ``-log(I / 255)``.

    Zero entries are treated as 1 so the logarithm stays finite. The input
    is left untouched, so read-only arrays are accepted.
    :param I: RGB array with values on a 0-255 scale
    :return: float32 array of optical densities, same shape as ``I``
    """
    # Substitute zeros on a fresh array rather than writing into the caller's
    # image.
    nonzero = np.where(I == 0, 1, I)
    return -1 * np.log(nonzero / 255).astype(np.float32)
214
215
216
def OD_to_RGB(OD):
    """
    Convert from optical density to RGB, ``255 * exp(-OD)``.

    :param OD: array of optical densities
    :return: uint8 array of the same shape, saturated to [0, 255]
    """
    rgb = 255 * np.exp(-1 * OD)
    # Negative OD gives values above 255; clip before the cast so they
    # saturate at 255 instead of overflowing uint8.
    return np.clip(rgb, 0, 255).astype(np.uint8)
223
224
225
def normalize_rows(A):
    """
    Scale each row of ``A`` to unit Euclidean length.

    :param A: array whose axis-1 vectors are to be normalized
    :return: array of the same shape as ``A``
    """
    row_norms = np.linalg.norm(A, axis=1, keepdims=True)
    return A / row_norms
232
233
234
def notwhite_mask(I, thresh=0.8):
    """
    Build a boolean mask that is True where a pixel is 'not white'.

    The image is converted from RGB to LAB with OpenCV; a pixel counts as
    not white when its first LAB channel, divided by 255, is below ``thresh``.
    :param I: RGB image accepted by ``cv2.cvtColor``
    :param thresh: cut-off applied to the first LAB channel divided by 255
    :return: 2-D boolean array
    """
    lab = cv2.cvtColor(I, cv2.COLOR_RGB2LAB)
    lightness = lab[:, :, 0] / 255.0
    return lightness < thresh
244
245
246
def sign(x):
    """
    Return the sign of x as +1, -1 or 0.

    :param x: scalar number
    :return: 1 if x > 0, -1 if x < 0, 0 if x == 0; None when none of the
        comparisons hold (e.g. NaN)
    """
    if x > 0:
        return 1
    if x < 0:
        return -1
    return 0 if x == 0 else None
258
259
260
def get_concentrations(I, stain_matrix, lamda=0.01):
    """
    Estimate per-pixel stain concentrations by least squares.

    :param I: RGB image; converted to optical density and flattened to one
        row per pixel
    :param stain_matrix: stain matrix with one row per stain and three
        columns (documented upstream as 2x3)
    :param lamda: not used by this implementation
    :return: array with one row per pixel and one column per stain
    """
    od_pixels = RGB_to_OD(I).reshape((-1, 3))

    # Solve stain_matrix.T @ C = OD.T for C (channels as rows, pixels as
    # columns), then return with pixels as rows.
    solution = np.linalg.lstsq(stain_matrix.T, od_pixels.T, rcond=None)[0]
    return solution.T
275
276
277
def clip_size(I, max_size=2048):
    """Shrink an image so that neither of its first two axes exceeds max_size.

    Images already within the limit are returned unchanged (same object).
    Otherwise the longer of the two axes becomes ``max_size`` and the other
    is scaled proportionally (truncated to an int) before ``cv2.resize``.
    """
    rows, cols = I.shape[0], I.shape[1]
    if not (rows > max_size or cols > max_size):
        return I

    if rows > cols:
        cols = int((cols / rows) * max_size)
        rows = max_size
    else:
        rows = int((rows / cols) * max_size)
        cols = max_size
    # cv2.resize takes the target size as (axis-1 length, axis-0 length).
    return cv2.resize(I, (cols, rows))
289
290
291
def _as_numpy(arg1: Union[List, np.ndarray]) -> np.ndarray:
292
    """Ensures array is a numpy array."""
293
294
    if isinstance(arg1, list):
295
        return np.squeeze(np.array(arg1)).astype(np.float32)
296
    elif isinstance(arg1, np.ndarray):
297
        return np.squeeze(arg1).astype(np.float32)
298
    else:
299
        raise ValueError(f'Expected numpy array; got {type(arg1)}')
300
301
# =============================================================================
302
303
import numpy as np
304
305
306
def unstack(a, axis = 0):
    """Split ``a`` along ``axis`` into a list of arrays with that axis removed."""
    pieces = np.split(a, a.shape[axis], axis=axis)
    return [np.squeeze(piece, axis) for piece in pieces]
308
309
310
def rgb_to_xyz(input):
    """
    Convert an RGB image to CIE XYZ.
    Args:
      input: A 3-D (`[H, W, 3]`) or 4-D (`[N, H, W, 3]`) float16, float32 or
        float64 array.
    Returns:
      An array of the same shape as `input`, holding XYZ values.
    """
    assert input.dtype in (np.float16, np.float32, np.float64)

    # Undo the piecewise transfer curve: power law above the threshold,
    # straight line below it.
    power_branch = np.power((input + 0.055) / 1.055, 2.4)
    linear_branch = input / 12.92
    linear_rgb = np.where(input > 0.04045, power_branch, linear_branch)

    # Apply the 3x3 matrix (matching the input precision) to the last axis.
    matrix = rgb_to_xyz_kernels[str(input.dtype)]
    return np.tensordot(linear_rgb, np.transpose(matrix), axes=((-1,), (0,)))
328
329
330
def xyz_to_rgb(input):
    """
    Convert a CIE XYZ image to RGB.
    Args:
      input: A 3-D (`[H, W, 3]`) or 4-D (`[N, H, W, 3]`) float16, float32 or
        float64 array.
    Returns:
      An array of the same shape as `input`, holding RGB values clipped to
      the range [0, 1].
    """
    assert input.dtype in (np.float16, np.float32, np.float64)

    # Apply the 3x3 matrix (matching the input precision) to the last axis.
    matrix = xyz_to_rgb_kernels[str(input.dtype)]
    linear_rgb = np.tensordot(input, np.transpose(matrix), axes=((-1,), (0,)))

    # Re-apply the piecewise transfer curve. Negatives are clipped before the
    # fractional power so that branch never sees a negative base.
    power_branch = np.power(np.clip(linear_rgb, 0, None), 1.0 / 2.4) * 1.055 - 0.055
    linear_branch = linear_rgb * 12.92
    encoded = np.where(linear_rgb > 0.0031308, power_branch, linear_branch)
    return np.clip(encoded, 0, 1)
349
350
351
def lab_to_rgb(input, illuminant="D65", observer="2"):
    """
    Convert a CIE LAB image to RGB.
    Args:
      input: A 3-D (`[H, W, 3]`) or 4-D (`[N, H, W, 3]`) float16, float32 or
        float64 array.
      illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
      observer : {"2", "10"}, optional
        The aperture angle of the observer.
    Returns:
      An array of the same shape as `input`, holding RGB values.
    """
    assert input.dtype in (np.float16, np.float32, np.float64)

    channels = unstack(input, axis=-1)
    l, a, b = channels[0], channels[1], channels[2]

    # Intermediate values derived from L, a and b; the z term is floored at 0.
    fy = (l + 16.0) / 116.0
    fx = (a / 500.0) + fy
    fz = np.clip(fy - (b / 200.0), 0, None)
    f = np.stack([fx, fy, fz], axis=-1)

    # Piecewise inverse: cube above the threshold, linear segment below it.
    cubed = np.power(f, 3.0)
    linear = (f - 16.0 / 116.0) / 7.787
    xyz = np.where(f > 0.2068966, cubed, linear)

    # Scale by the coordinates of the chosen illuminant/observer.
    coords = np.array(illuminants[illuminant.upper()][observer], input.dtype)
    return xyz_to_rgb(xyz * coords)
389
390
391
def rgb_to_lab(input, illuminant="D65", observer="2"):
    """
    Convert an RGB image to CIE LAB.
    Args:
      input: A 3-D (`[H, W, 3]`) or 4-D (`[N, H, W, 3]`) float16, float32 or
        float64 array.
      illuminant : {"A", "D50", "D55", "D65", "D75", "E"}, optional
        The name of the illuminant (the function is NOT case sensitive).
      observer : {"2", "10"}, optional
        The aperture angle of the observer.
    Returns:
      An array of the same shape as `input`, holding LAB values.
    """
    assert input.dtype in (np.float16, np.float32, np.float64)

    coords = np.array(illuminants[illuminant.upper()][observer], input.dtype)

    # Divide XYZ by the coordinates of the chosen illuminant/observer.
    scaled = rgb_to_xyz(input) / coords

    # Piecewise forward function: cube root above the threshold, linear
    # segment below it.
    cube_root = np.power(scaled, 1.0 / 3.0)
    linear = scaled * 7.787 + 16.0 / 116.0
    f = np.where(scaled > 0.008856, cube_root, linear)

    parts = unstack(f, axis=-1)
    fx, fy, fz = parts[0], parts[1], parts[2]

    lightness = (fy * 116.0) - 16.0
    a_axis = (fx - fy) * 500.0
    b_axis = (fy - fz) * 200.0
    return np.stack([lightness, a_axis, b_axis], axis=-1)
427
428
# -----------------------------------------------------------------------------
429
430
431
# --- Numpy and CV2-based LAB-RGB utility functions. -----------------------------
432
433
def merge_back_cv2(I1: np.ndarray, I2: np.ndarray, I3: np.ndarray) -> np.ndarray:
    """Take separate LAB channels and merge back to give RGB uint8.

    The first channel is multiplied by 2.55 and the other two are shifted by
    +128 before the merged image is clipped to [0, 255], cast to uint8 and
    converted with ``cv2.COLOR_LAB2RGB``. The input arrays are not modified.

    Args:
        I1 (np.ndarray): First channel.
        I2 (np.ndarray): Second channel.
        I3 (np.ndarray): Third channel.

    Returns:
        np.ndarray: RGB uint8 image.
    """
    # Rescale out of place: in-place `*=` / `+=` would overwrite the caller's
    # arrays and fails outright on integer or read-only inputs.
    lab = cv2.merge((I1 * 2.55, I2 + 128.0, I3 + 128.0))
    lab = np.clip(lab, 0, 255).astype(np.uint8)
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
449
450
def lab_split_cv2(I: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Convert from RGB uint8 to LAB with OpenCV and split into channels.

    After conversion the first channel is divided by 2.55 and the second and
    third channels are shifted by -128.

    Args:
        I (np.ndarray): RGB uint8 image.

    Returns:
        np.ndarray: I1, first channel.

        np.ndarray: I2, second channel.

        np.ndarray: I3, third channel.
    """
    lab = cv2.cvtColor(I, cv2.COLOR_RGB2LAB).astype(np.float32)
    first, second, third = cv2.split(lab)
    # These arrays were produced by the split above, so in-place edits do not
    # affect the caller's image.
    first /= 2.55
    second -= 128.0
    third -= 128.0
    return first, second, third
470
471
# -----------------------------------------------------------------------------
472
473
def lab_split_numpy(I: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Convert from RGB uint8 to LAB with numpy and split into channels.

    The image is cast to float32 and divided by 255 before conversion with
    ``rgb_to_lab``. The channels come back as the list produced by
    ``unstack``.

    Args:
        I (np.ndarray): RGB uint8 image.

    Returns:
        np.ndarray: I1, first channel.

        np.ndarray: I2, second channel.

        np.ndarray: I3, third channel.
    """
    unit_rgb = I.astype(np.float32)
    unit_rgb /= 255
    return unstack(rgb_to_lab(unit_rgb), axis=-1)
490
491
492
def merge_back_numpy(I1: np.ndarray, I2: np.ndarray, I3: np.ndarray) -> np.ndarray:
    """Take separate LAB channels and merge back to give RGB uint8.

    The channels are stacked on a new last axis, converted with
    ``lab_to_rgb``, multiplied by 255, truncated to integers and clipped to
    [0, 255].

    Args:
        I1 (np.ndarray): First channel.
        I2 (np.ndarray): Second channel.
        I3 (np.ndarray): Third channel.

    Returns:
        np.ndarray: RGB uint8 image.
    """
    lab = np.stack((I1, I2, I3), axis=-1)
    scaled = lab_to_rgb(lab) * 255
    as_int = scaled.astype(np.int32)
    return np.clip(as_int, 0, 255).astype(np.uint8)
507
    return I