a b/brats_toolkit/fusionator.py
1
# -*- coding: utf-8 -*-
2
# Author: Christoph Berger
3
# Script for the fusion of segmentation labels
4
#
5
# Please refer to README.md and LICENSE.md for further documentation
6
# This software is not certified for clinical use.
7
import itertools
8
import logging
9
import math
10
import os
11
import os.path as op
12
13
import numpy as np
14
15
from .util import filemanager as fm
16
from .util import own_itk as oitk
17
from .util.citation_reminder import citation_reminder
18
19
20
class Fusionator(object):
21
    @citation_reminder
    def __init__(self, verbose=True):
        """
        Create a fusionator.

        Args:
            verbose (bool, optional): when truthy, the fusion methods print additional
                diagnostic output. Defaults to True.
        """
        self.verbose = verbose
24
25
    def _binaryMav(self, candidates, weights=None):
26
        """
27
        binaryMav performs majority vote fusion on an arbitary number of input segmentations with
28
        only two classes each (1 and 0).
29
30
        Args:
31
            candidates (list): the candidate segmentations as binary numpy arrays of same shape
32
            weights (list, optional): associated weights for each segmentation in candidates. Defaults to None.
33
34
        Return
35
            array: a numpy array with the majority vote result
36
        """
37
        num = len(candidates)
38
        if weights == None:
39
            weights = itertools.repeat(1, num)
40
        # manage empty calls
41
        if num == 0:
42
            print("ERROR! No segmentations to fuse.")
43
        elif num == 1:
44
            return candidates[0]
45
        if self.verbose:
46
            print(
47
                "Number of segmentations to be fused using compound majority vote is: ",
48
                num,
49
            )
50
            for c in candidates:
51
                print(
52
                    "Candidate with shape {} and values {} and sum {}".format(
53
                        c.shape, np.unique(c), np.sum(c)
54
                    )
55
                )
56
        # load first segmentation and use it to create initial numpy arrays
57
        temp = candidates[0]
58
        result = np.zeros(temp.shape)
59
        # loop through all available segmentations and tally votes for each class
60
        label = np.zeros(temp.shape)
61
        for c, w in zip(candidates, weights):
62
            if c.max() != 1 or c.min() != 0:
63
                logging.warning(
64
                    "The passed segmentation contains labels other than 1 and 0."
65
                )
66
            print("weight is: " + str(w))
67
            label[c == 1] += 1.0 * w
68
        num = sum(weights)
69
        result[label >= (num / 2.0)] = 1
70
        if self.verbose:
71
            print("Shape of result:", result.shape)
72
            print("Shape of current input array:", temp.shape)
73
            print(
74
                "Labels and datatype of current output:",
75
                result.max(),
76
                result.min(),
77
                result.dtype,
78
            )
79
        return result
80
81
    def _mav(self, candidates, labels=None, weights=None):
82
        """
83
        mav performs majority vote fusion on an arbitary number of input segmentations with
84
        an arbitrary number of labels.
85
86
        Args:
87
            candidates (list): the candidate segmentations as binary numpy arrays of same shape
88
            labels (list, optional): a list of labels present in the candidates. Defaults to None.
89
            weights (list, optional): weights for the fusion. Defaults to None.
90
91
        Returns:
92
            array: a numpy array with the majority vote result
93
        """
94
        num = len(candidates)
95
        if weights == None:
96
            weights = itertools.repeat(1, num)
97
        # manage empty calls
98
        if num == 0:
99
            print("ERROR! No segmentations to fuse.")
100
        if self.verbose:
101
            print(
102
                "Number of segmentations to be fused using compound majority vote is: ",
103
                num,
104
            )
105
        # if no labels are passed, get the labels from the first input file (might lead to misisng labels!)
106
        if labels == None:
107
            labels = np.unique(candidates[0])
108
            for c in candidates:
109
                labels = np.append(labels, np.unique(c))
110
                print(
111
                    "Labels of current candidate: {}, dtype: {}".format(
112
                        np.unique(c), c.dtype
113
                    )
114
                )
115
            labels = np.unique(labels).astype(int)
116
            logging.warning(
117
                "No labels passed, choosing those labels automatically: {}".format(
118
                    labels
119
                )
120
            )
121
        # remove background label
122
        if 0 in labels:
123
            labels = np.delete(labels, 0)
124
        # load first segmentation and use it to create initial numpy arrays
125
        temp = candidates[0]
126
        result = np.zeros(temp.shape)
127
        # loop through all available segmentations and tally votes for each class
128
        print("Labels: {}".format(labels))
129
        for l in sorted(labels, reverse=True):
130
            label = np.zeros(temp.shape)
131
            num = 0
132
            for c, w in zip(candidates, weights):
133
                print("weight is: " + str(w))
134
                label[c == l] += 1.0 * w
135
            num = sum(weights)
136
            print(num)
137
            result[label >= (num / 2.0)] = l
138
        if self.verbose:
139
            print("Shape of result:", result.shape)
140
            print(
141
                "Labels and datatype of result:",
142
                result.max(),
143
                result.min(),
144
                result.dtype,
145
            )
146
        return result
147
148
    def _brats_simple(
149
        self,
150
        candidates,
151
        weights=None,
152
        t=0.05,
153
        stop=25,
154
        inc=0.07,
155
        method="dice",
156
        iterations=25,
157
    ):
158
        """
159
        BRATS DOMAIN ADAPTED!!!!! simple implementation using DICE scoring
160
        Iteratively estimates the accuracy of the segmentations and dynamically assigns weights
161
        for the next iteration. Continues for each label until convergence is reached.
162
163
        Args:
164
            candidates (list): [description]
165
            weights (list, optional): [description]. Defaults to None.
166
            t (float, optional): [description]. Defaults to 0.05.
167
            stop (int, optional): [description]. Defaults to 25.
168
            inc (float, optional): [description]. Defaults to 0.07.
169
            method (str, optional): [description]. Defaults to 'dice'.
170
            iterations (int, optional): [description]. Defaults to 25.
171
            labels (list, optional): [description]. Defaults to None.
172
173
        Raises:
174
            IOError: If no segmentations to be fused are passed
175
176
        Returns:
177
            array: a numpy array with the SIMPLE fusion result
178
        """
179
        # manage empty calls
180
        num = len(candidates)
181
        if num == 0:
182
            print("ERROR! No segmentations to fuse.")
183
            raise IOError("No valid segmentations passed for SIMPLE Fusion")
184
        if self.verbose:
185
            print("Number of segmentations to be fused using SIMPLE is: ", num)
186
        # handle unpassed weights
187
        if weights == None:
188
            weights = itertools.repeat(1, num)
189
        backup_weights = weights  # ugly save to reset weights after each round
190
        # get unique labels for multi-class fusion
191
192
        result = np.zeros(candidates[0].shape)
193
        labels = [2, 1, 4]
194
        logging.info("Fusing a segmentation with the labels: {}".format(labels))
195
        # loop over each label
196
        for l in labels:
197
            if self.verbose:
198
                print("Currently fusing label {}".format(l))
199
            # load first segmentation and use it to create initial numpy arrays IFF it contains labels
200
            if l == 2:
201
                # whole tumor
202
                bin_candidates = [(c > 0).astype(int) for c in candidates]
203
            elif l == 1:
204
                # tumor core
205
                bin_candidates = [((c == 1) | (c == 4)).astype(int) for c in candidates]
206
            else:
207
                # active tumor
208
                bin_candidates = [(c == 4).astype(int) for c in candidates]
209
            if self.verbose:
210
                print(bin_candidates[0].shape)
211
            # baseline estimate
212
            estimate = self._binaryMav(bin_candidates, weights)
213
            # initial convergence baseline
214
            conv = np.sum(estimate)
215
            # check if the estimate was reasonable
216
            if conv == 0:
217
                logging.error("Majority Voting in SIMPLE returned an empty array")
218
                # return np.zeros(candidates[0].shape)
219
            # reset tau before each iteration
220
            tau = t
221
            for i in range(iterations):
222
                t_weights = []  # temporary weights
223
                for c in bin_candidates:
224
                    # score all canidate segmentations
225
                    t_weights.append(
226
                        (self._score(c, estimate, method) + 1) ** 2
227
                    )  # SQUARED DICE!
228
                weights = t_weights
229
                # save maximum score in weights
230
                max_phi = max(weights)
231
                # remove dropout estimates
232
                bin_candidates = [
233
                    c for c, w in zip(bin_candidates, weights) if (w > t * max_phi)
234
                ]
235
                # calculate new estimate
236
                estimate = self._binaryMav(bin_candidates, weights)
237
                # increment tau
238
                tau = tau + inc
239
                # check if it converges
240
                if np.abs(conv - np.sum(estimate)) < stop:
241
                    if self.verbose:
242
                        print(
243
                            "Convergence for label {} after {} iterations reached.".format(
244
                                l, i
245
                            )
246
                        )
247
                    break
248
                conv = np.sum(estimate)
249
            # assign correct label to result
250
            result[estimate == 1] = l
251
            # reset weights
252
            weights = backup_weights
253
        if self.verbose:
254
            print("Shape of result:", result.shape)
255
            print("Shape of current input array:", bin_candidates[0].shape)
256
            print(
257
                "Labels and datatype of current output:",
258
                result.max(),
259
                result.min(),
260
                result.dtype,
261
            )
262
        return result
263
264
    def _simple(
265
        self,
266
        candidates,
267
        weights=None,
268
        t=0.05,
269
        stop=25,
270
        inc=0.07,
271
        method="dice",
272
        iterations=25,
273
        labels=None,
274
    ):
275
        """
276
        simple implementation using DICE scoring
277
        Iteratively estimates the accuracy of the segmentations and dynamically assigns weights
278
        for the next iteration. Continues for each label until convergence is reached.
279
280
        Args:
281
            candidates (list): [description]
282
            weights (list, optional): [description]. Defaults to None.
283
            t (float, optional): [description]. Defaults to 0.05.
284
            stop (int, optional): [description]. Defaults to 25.
285
            inc (float, optional): [description]. Defaults to 0.07.
286
            method (str, optional): [description]. Defaults to 'dice'.
287
            iterations (int, optional): [description]. Defaults to 25.
288
            labels (list, optional): [description]. Defaults to None.
289
290
        Raises:
291
            IOError: If no segmentations to be fused are passed
292
293
        Returns:
294
            array: a numpy array with the SIMPLE fusion result
295
        """
296
        # manage empty calls
297
        num = len(candidates)
298
        if num == 0:
299
            print("ERROR! No segmentations to fuse.")
300
            raise IOError("No valid segmentations passed for SIMPLE Fusion")
301
        if self.verbose:
302
            print("Number of segmentations to be fused using SIMPLE is: ", num)
303
        # handle unpassed weights
304
        if weights == None:
305
            weights = itertools.repeat(1, num)
306
        backup_weights = weights  # ugly save to reset weights after each round
307
        # get unique labels for multi-class fusion
308
        if labels == None:
309
            labels = np.unique(candidates[0])
310
            for c in candidates:
311
                labels = np.append(labels, np.unique(c))
312
                print(
313
                    "Labels of current candidate: {}, dtype: {}".format(
314
                        np.unique(c), c.dtype
315
                    )
316
                )
317
            labels = np.unique(labels).astype(int)
318
            logging.warning(
319
                "No labels passed, choosing those labels automatically: {}".format(
320
                    labels
321
                )
322
            )
323
        result = np.zeros(candidates[0].shape)
324
        # remove background label
325
        if 0 in labels:
326
            labels = np.delete(labels, 0)
327
        logging.info("Fusing a segmentation with the labels: {}".format(labels))
328
        # loop over each label
329
        for l in sorted(labels):
330
            if self.verbose:
331
                print("Currently fusing label {}".format(l))
332
            # load first segmentation and use it to create initial numpy arrays IFF it contains labels
333
            bin_candidates = [(c == l).astype(int) for c in candidates]
334
            if self.verbose:
335
                print(bin_candidates[0].shape)
336
            # baseline estimate
337
            estimate = self._binaryMav(bin_candidates, weights)
338
            # initial convergence baseline
339
            conv = np.sum(estimate)
340
            # check if the estimate was reasonable
341
            if conv == 0:
342
                logging.error("Majority Voting in SIMPLE returned an empty array")
343
                # return np.zeros(candidates[0].shape)
344
            # reset tau before each iteration
345
            tau = t
346
            for i in range(iterations):
347
                t_weights = []  # temporary weights
348
                for c in bin_candidates:
349
                    # score all canidate segmentations
350
                    t_weights.append(
351
                        (self._score(c, estimate, method) + 1) ** 2
352
                    )  # SQUARED DICE!
353
                weights = t_weights
354
                # save maximum score in weights
355
                max_phi = max(weights)
356
                # remove dropout estimates
357
                bin_candidates = [
358
                    c for c, w in zip(bin_candidates, weights) if (w > t * max_phi)
359
                ]
360
                # calculate new estimate
361
                estimate = self._binaryMav(bin_candidates, weights)
362
                # increment tau
363
                tau = tau + inc
364
                # check if it converges
365
                if np.abs(conv - np.sum(estimate)) < stop:
366
                    if self.verbose:
367
                        print(
368
                            "Convergence for label {} after {} iterations reached.".format(
369
                                l, i
370
                            )
371
                        )
372
                    break
373
                conv = np.sum(estimate)
374
            # assign correct label to result
375
            result[estimate == 1] = l
376
            # reset weights
377
            weights = backup_weights
378
        if self.verbose:
379
            print("Shape of result:", result.shape)
380
            print("Shape of current input array:", bin_candidates[0].shape)
381
            print(
382
                "Labels and datatype of current output:",
383
                result.max(),
384
                result.min(),
385
                result.dtype,
386
            )
387
        return result
388
389
    def _dirFuse(self, directory, method="mav", outputPath=None, labels=None):
390
        """
391
        dirFuse [summary]
392
393
        Args:
394
            directory ([type]): [description]
395
            method (str, optional): [description]. Defaults to 'mav'.
396
            outputName ([type], optional): [description]. Defaults to None.
397
        """
398
        if method == "all":
399
            return
400
        candidates = []
401
        weights = []
402
        temp = None
403
        for file in os.listdir(directory):
404
            if file.endswith(".nii.gz"):
405
                # skip existing fusions
406
                if "fusion" in file:
407
                    continue
408
                temp = op.join(directory, file)
409
                try:
410
                    candidates.append(oitk.get_itk_array(oitk.get_itk_image(temp)))
411
                    weights.append(1)
412
                    print("Loaded: " + os.path.join(directory, file))
413
                except Exception as e:
414
                    print(
415
                        "Could not load this file: "
416
                        + file
417
                        + " \nPlease check if this is a valid path and that the files exists. Exception: "
418
                        + e
419
                    )
420
        if method == "mav":
421
            print(
422
                "Orchestra: Now fusing all .nii.gz files in directory {} using MAJORITY VOTING. For more output, set the -v or --verbose flag or instantiate the fusionator class with verbose=true".format(
423
                    directory
424
                )
425
            )
426
            result = self._mav(candidates, labels, weights)
427
        elif method == "simple":
428
            print(
429
                "Orchestra: Now fusing all .nii.gz files in directory {} using SIMPLE. For more output, set the -v or --verbose flag or instantiate the fusionator class with verbose=true".format(
430
                    directory
431
                )
432
            )
433
            result = self._simple(candidates, weights)
434
        elif method == "brats-simple":
435
            print(
436
                "Orchestra: Now fusing all .nii.gz files in directory {} using BRATS-SIMPLE. For more output, set the -v or --verbose flag or instantiate the fusionator class with verbose=true".format(
437
                    directory
438
                )
439
            )
440
            result = self._brats_simple(candidates, weights)
441
        try:
442
            if outputPath == None:
443
                oitk.write_itk_image(
444
                    oitk.make_itk_image(result, proto_image=oitk.get_itk_image(temp)),
445
                    op.join(directory, method + "_fusion.nii.gz"),
446
                )
447
            else:
448
                outputDir = op.dirname(outputPath)
449
                os.makedirs(outputDir, exist_ok=True)
450
                oitk.write_itk_image(
451
                    oitk.make_itk_image(result, proto_image=oitk.get_itk_image(temp)),
452
                    outputPath,
453
                )
454
            logging.info(
455
                "Segmentation Fusion with method {} saved in directory {}.".format(
456
                    method, directory
457
                )
458
            )
459
        except Exception as e:
460
            print("Very bad, this should also be logged somewhere: " + str(e))
461
            logging.exception(
462
                "Issues while saving the resulting segmentation: {}".format(str(e))
463
            )
464
465
    def fuse(self, segmentations, outputPath, method="mav", weights=None, labels=None):
466
        """
467
        fuse [summary]
468
469
        Args:
470
            segmentations ([type]): [description]
471
            outputPath ([type]): [description]
472
            method (str, optional): [description]. Defaults to 'mav'.
473
            weights ([type], optional): [description]. Defaults to None.
474
475
        Raises:
476
            IOError: [description]
477
        """
478
        candidates = []
479
        if weights is not None:
480
            if len(weights) != len(segmentations):
481
                raise IOError(
482
                    "Please pass a matching number of weights and segmentation files"
483
                )
484
            w_weights = weights
485
        else:
486
            w_weights = []
487
        for seg in segmentations:
488
            if seg.endswith(".nii.gz"):
489
                try:
490
                    candidates.append(oitk.get_itk_array(oitk.get_itk_image(seg)))
491
                    if weights is None:
492
                        w_weights.append(1)
493
                    print("Loaded: " + seg)
494
                except Exception as e:
495
                    print(
496
                        "Could not load this file: "
497
                        + seg
498
                        + " \nPlease check if this is a valid path and that the files exists. Exception: "
499
                        + str(e)
500
                    )
501
                    raise
502
        if method == "mav":
503
            print(
504
                "Orchestra: Now fusing all passed .nii.gz files using MAJORITY VOTING. For more output, set the -v or --verbose flag or instantiate the fusionator class with verbose=true"
505
            )
506
            result = self._mav(candidates, labels=labels, weights=w_weights)
507
        elif method == "simple":
508
            print(
509
                "Orchestra: Now fusing all passed .nii.gz files in using SIMPLE. For more output, set the -v or --verbose flag or instantiate the fusionator class with verbose=true"
510
            )
511
            result = self._simple(candidates, w_weights)
512
        elif method == "brats-simple":
513
            print(
514
                "Orchestra: Now fusing all .nii.gz files in directory {} using BRATS-SIMPLE. For more output, set the -v or --verbose flag or instantiate the fusionator class with verbose=true"
515
            )
516
            result = self._brats_simple(candidates, w_weights)
517
        try:
518
            outputDir = op.dirname(outputPath)
519
            os.makedirs(outputDir, exist_ok=True)
520
            oitk.write_itk_image(
521
                oitk.make_itk_image(result, proto_image=oitk.get_itk_image(seg)),
522
                outputPath,
523
            )
524
            logging.info(
525
                "Segmentation Fusion with method {} saved as {}.".format(
526
                    method, outputPath
527
                )
528
            )
529
        except Exception as e:
530
            print("Very bad, this should also be logged somewhere: " + str(e))
531
            logging.exception(
532
                "Issues while saving the resulting segmentation: {}".format(str(e))
533
            )
534
535
    def _score(self, seg, gt, method="dice"):
536
        """Calculates a similarity score based on the
537
        method specified in the parameters
538
        Input: Numpy arrays to be compared, need to have the
539
        same dimensions (shape)
540
        Default scoring method: DICE coefficient
541
        method may be:  'dice'
542
                        'auc'
543
                        'bdice'
544
        returns: a score [0,1], 1 for identical inputs
545
        """
546
        try:
547
            # True Positive (TP): we predict a label of 1 (positive) and the true label is 1.
548
            TP = np.sum(np.logical_and(seg == 1, gt == 1))
549
            # True Negative (TN): we predict a label of 0 (negative) and the true label is 0.
550
            TN = np.sum(np.logical_and(seg == 0, gt == 0))
551
            # False Positive (FP): we predict a label of 1 (positive), but the true label is 0.
552
            FP = np.sum(np.logical_and(seg == 1, gt == 0))
553
            # False Negative (FN): we predict a label of 0 (negative), but the true label is 1.
554
            FN = np.sum(np.logical_and(seg == 0, gt == 1))
555
            FPR = FP / (FP + TN)
556
            FNR = FN / (FN + TP)
557
            TPR = TP / (TP + FN)
558
            TNR = TN / (TN + FP)
559
        except ValueError:
560
            print("Value error encountered!")
561
            return 0
562
        # faster dice? Oh yeah!
563
        if method == "dice":
564
            # default dice score
565
            score = 2 * TP / (2 * TP + FP + FN)
566
        elif method == "auc":
567
            # AUC scoring
568
            score = 1 - (FPR + FNR) / 2
569
        elif method == "bdice":
570
            # biased dice towards false negatives
571
            score = 2 * TP / (2 * TP + FN)
572
        elif method == "spec":
573
            # specificity
574
            score = TN / (TN + FP)
575
        elif method == "sens":
576
            # sensitivity
577
            score = TP / (TP + FN)
578
        elif method == "toterr":
579
            score = (FN + FP) / (155 * 240 * 240)
580
        elif method == "ppv":
581
            prev = np.sum(gt) / (155 * 240 * 240)
582
            temp = TPR * prev
583
            score = (temp) / (temp + (1 - TNR) * (1 - prev))
584
        else:
585
            score = 0
586
        if np.isnan(score) or math.isnan(score):
587
            score = 0
588
        return score