Diff of /predictor.py [000000] .. [bb7f56]

Switch to unified view

a b/predictor.py
1
#!/usr/bin/env python
2
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
# ==============================================================================
16
17
import os
18
import code
19
import numpy as np
20
import torch
21
from scipy.stats import norm
22
from collections import OrderedDict
23
from multiprocessing import Pool
24
import pickle
25
from copy import deepcopy
26
import pandas as pd
27
28
import utils.exp_utils as utils
29
from plotting import plot_batch_prediction
30
31
32
class Predictor:
33
    """
34
    Prediction pipeline:
35
    - receives a patched patient image (n_patches, c, y, x, (z)) from patient data loader.
36
    - forwards patches through model in chunks of batch_size. (method: batch_tiling_forward)
37
    - unmolds predictions (boxes and segmentations) to original patient coordinates. (method: spatial_tiling_forward)
38
39
    Ensembling (mode == 'test'):
40
    - for inference, forwards 4 mirrored versions of image to through model and unmolds predictions afterwards
41
      accordingly (method: data_aug_forward)
42
    - for inference, loads multiple parameter-sets of the trained model corresponding to different epochs. for each
43
      parameter-set loops over entire test set, runs prediction pipeline for each patient. (method: predict_test_set)
44
45
    Consolidation of predictions:
46
    - consolidates a patient's predictions (boxes, segmentations) collected over patches, data_aug- and temporal ensembling,
47
      performs clustering and weighted averaging (external function: apply_wbc_to_patient) to obtain consistent outptus.
48
    - for 2D networks, consolidates box predictions to 3D cubes via clustering (adaption of non-maximum surpression).
49
      (external function: merge_2D_to_3D_preds_per_patient)
50
51
    Ground truth handling:
52
    - dissmisses any ground truth boxes returned by the model (happens in validation mode, patch-based groundtruth)
53
    - if provided by data loader, adds 3D ground truth to the final predictions to be passed to the evaluator.
54
    """
55
    def __init__(self, cf, net, logger, mode):
56
57
        self.cf = cf
58
        self.logger = logger
59
60
        # mode is 'val' for patient-based validation/monitoring and 'test' for inference.
61
        self.mode = mode
62
63
        # model instance. In validation mode, contains parameters of current epoch.
64
        self.net = net
65
66
        # rank of current epoch loaded (for temporal averaging). this info is added to each prediction,
67
        # for correct weighting during consolidation.
68
        self.rank_ix = '0'
69
70
        # number of ensembled models. used to calculate the number of expected predictions per position
71
        # during consolidation of predictions. Default is 1 (no ensembling, e.g. in validation).
72
        self.n_ens = 1
73
74
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
75
76
        if self.mode == 'test':
77
            try:
78
                self.epoch_ranking = np.load(os.path.join(self.cf.fold_dir, 'epoch_ranking.npy'))[:cf.test_n_epochs]
79
            except:
80
                raise RuntimeError('no epoch ranking file in fold directory. '
81
                                   'seems like you are trying to run testing without prior training...')
82
            self.n_ens = cf.test_n_epochs
83
            if self.cf.test_aug:
84
                self.n_ens *= 4
85
86
            self.example_plot_dir = os.path.join(cf.test_dir, "example_plots")
87
            os.makedirs(self.example_plot_dir, exist_ok=True)
88
89
90
    def predict_patient(self, batch):
91
        """
92
        predicts one patient.
93
        called either directly via loop over validation set in exec.py (mode=='val')
94
        or from self.predict_test_set (mode=='test).
95
        in val mode:  adds 3D ground truth info to predictions and runs consolidation and 2Dto3D merging of predictions.
96
        in test mode: returns raw predictions (ground truth addition, consolidation, 2D to 3D merging are
97
                      done in self.predict_test_set, because patient predictions across several epochs might be needed
98
                      to be collected first, in case of temporal ensembling).
99
        :return. results_dict: stores the results for one patient. dictionary with keys:
100
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
101
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
102
                            (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
103
                 - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
104
                 - losses (only in validation mode)
105
        """
106
        #self.logger.info('\revaluating patient {} for fold {} '.format(batch['pid'], self.cf.fold))
107
        print('\revaluating patient {} for fold {} '.format(batch['pid'], self.cf.fold), end="", flush=True)
108
109
        # True if patient is provided in patches and predictions need to be tiled.
110
        self.patched_patient = 'patch_crop_coords' in batch.keys()
111
112
        # forward batch through prediction pipeline.
113
        results_dict = self.data_aug_forward(batch)
114
115
        if self.mode == 'val':
116
            for b in range(batch['patient_bb_target'].shape[0]):
117
                for t in range(len(batch['patient_bb_target'][b])):
118
                    results_dict['boxes'][b].append({'box_coords': batch['patient_bb_target'][b][t],
119
                                                     'box_label': batch['patient_roi_labels'][b][t],
120
                                                     'box_type': 'gt'})
121
122
            if self.patched_patient:
123
                wcs_input = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.wcs_iou, self.n_ens]
124
                results_dict['boxes'] = apply_wbc_to_patient(wcs_input)[0]
125
126
            if self.cf.merge_2D_to_3D_preds:
127
                merge_dims_inputs = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.merge_3D_iou]
128
                results_dict['boxes'] = merge_2D_to_3D_preds_per_patient(merge_dims_inputs)[0]
129
130
        return results_dict
131
132
133
    def predict_test_set(self, batch_gen, return_results=True):
134
        """
135
        wrapper around test method, which loads multiple (or one) epoch parameters (temporal ensembling), loops through
136
        the test set and collects predictions per patient. Also flattens the results per patient and epoch
137
        and adds optional ground truth boxes for evaluation. Saves out the raw result list for later analysis and
138
        optionally consolidates and returns predictions immediately.
139
        :return: (optionally) list_of_results_per_patient: list over patient results. each entry is a dict with keys:
140
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
141
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
142
                            (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
143
                 - 'seg_preds': not implemented yet. todo for evaluation of instance/semantic segmentation.
144
        """
145
        dict_of_patient_results = OrderedDict()
146
147
        # get paths of all parameter sets to be loaded for temporal ensembling. (or just one for no temp. ensembling).
148
        weight_paths = [os.path.join(self.cf.fold_dir, '{}_best_checkpoint'.format(epoch), 'params.pth') for epoch in
149
                        self.epoch_ranking]
150
151
        print (weight_paths)
152
153
        n_test_plots = min(batch_gen['n_test'], 1)
154
155
        for rank_ix, weight_path in enumerate(weight_paths):
156
157
            self.logger.info(('tmp ensembling over rank_ix:{} epoch:{}'.format(rank_ix, weight_path)))
158
            #code.interact(local=locals())
159
            self.net.load_state_dict(torch.load(weight_path))
160
            self.net.eval()
161
            self.rank_ix = str(rank_ix)  # get string of current rank for unique patch ids.
162
            plot_batches = np.random.choice(np.arange(batch_gen['n_test']), size=n_test_plots, replace=False)
163
164
            with torch.no_grad():
165
                for i in range(batch_gen['n_test']):
166
167
                    batch = next(batch_gen['test'])
168
169
                    # store batch info in patient entry of results dict.
170
                    if rank_ix == 0:
171
                        dict_of_patient_results[batch['pid']] = {}
172
                        dict_of_patient_results[batch['pid']]['results_dicts'] = []
173
                        dict_of_patient_results[batch['pid']]['patient_bb_target'] = batch['patient_bb_target']
174
                        dict_of_patient_results[batch['pid']]['patient_roi_labels'] = batch['patient_roi_labels']
175
176
                    # call prediction pipeline and store results in dict.
177
                    results_dict = self.predict_patient(batch)
178
                    dict_of_patient_results[batch['pid']]['results_dicts'].append({"boxes": results_dict['boxes']})
179
                    
180
                    if i in plot_batches and not self.patched_patient:
181
                        # view qualitative results of random test case
182
                        # plotting for patched patients is too expensive, thus not done. Change at will.
183
                        try:
184
                            out_file = os.path.join(self.example_plot_dir,
185
                                                    'batch_example_test_{}_rank_{}.png'.format(self.cf.fold,
186
                                                                                               rank_ix))
187
                            results_for_plotting = deepcopy(results_dict)
188
                            # seg preds of test augs are included separately. for viewing, only show aug 0 (merging
189
                            # would need multiple changes, incl in every model).
190
                            if results_for_plotting["seg_preds"].shape[1] > 1:
191
                                results_for_plotting["seg_preds"] = results_dict['seg_preds'][:, [0]]
192
                            for bix in range(batch["seg"].shape[0]): # batch dim should be 1
193
                                for tix in range(len(batch['bb_target'][bix])):
194
                                    results_for_plotting['boxes'][bix].append({'box_coords': batch['bb_target'][bix][tix],
195
                                                                       'box_label': batch['class_target'][bix][tix],
196
                                                                       'box_type': 'gt'})
197
                            utils.split_off_process(plot_batch_prediction, batch, results_for_plotting, self.cf,
198
                                                    outfile=out_file, suptitle="Test plot:\nunmerged TTA overlayed.")
199
                        except Exception as e:
200
                            self.logger.info("WARNING: error in plotting example test batch: {}".format(e))
201
202
203
        self.logger.info('finished predicting test set. starting post-processing of predictions.')
204
        results_per_patient = []
205
206
        # loop over patients again to flatten results across epoch predictions.
207
        # if provided, add ground truth boxes for evaluation.
208
        for pid, p_dict in dict_of_patient_results.items():
209
210
            tmp_ens_list = p_dict['results_dicts']
211
            results_dict = {}
212
            # collect all boxes/seg_preds of same batch_instance over temporal instances.
213
            b_size = len(tmp_ens_list[0]["boxes"])
214
            results_dict['boxes'] = [[item for rank_dict in tmp_ens_list for item in rank_dict["boxes"][batch_instance]]
215
                                     for batch_instance in range(b_size)]
216
217
            # TODO return for instance segmentation:
218
            # results_dict['seg_preds'] = np.mean(results_dict['seg_preds'], 1)[:, None]
219
            # results_dict['seg_preds'] = np.array([[item for d in tmp_ens_list for item in d['seg_preds'][batch_instance]]
220
            #                                       for batch_instance in range(len(tmp_ens_list[0]['boxes']))])
221
222
            # add 3D ground truth boxes for evaluation.
223
            for b in range(p_dict['patient_bb_target'].shape[0]):
224
                for t in range(len(p_dict['patient_bb_target'][b])):
225
                    results_dict['boxes'][b].append({'box_coords': p_dict['patient_bb_target'][b][t],
226
                                                     'box_label': p_dict['patient_roi_labels'][b][t],
227
                                                     'box_type': 'gt'})
228
            results_per_patient.append([results_dict, pid])
229
230
        # save out raw predictions.
231
        out_string = 'raw_pred_boxes_hold_out_list' if self.cf.hold_out_test_set else 'raw_pred_boxes_list'
232
        with open(os.path.join(self.cf.fold_dir, '{}.pickle'.format(out_string)), 'wb') as handle:
233
            pickle.dump(results_per_patient, handle)
234
235
        if return_results:
236
            final_patient_box_results = [(res_dict["boxes"], pid) for res_dict, pid in results_per_patient]
237
            # consolidate predictions.
238
            self.logger.info('applying wcs to test set predictions with iou = {} and n_ens = {}.'.format(
239
                self.cf.wcs_iou, self.n_ens))
240
            pool = Pool(processes=8)
241
            mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.wcs_iou, self.n_ens] for ii in final_patient_box_results]
242
            final_patient_box_results = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1)
243
            pool.close()
244
            pool.join()
245
246
            # merge 2D boxes to 3D cubes. (if model predicts 2D but evaluation is run in 3D)
247
            if self.cf.merge_2D_to_3D_preds:
248
                self.logger.info('applying 2Dto3D merging to test set predictions with iou = {}.'.format(self.cf.merge_3D_iou))
249
                pool = Pool(processes=6)
250
                mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in final_patient_box_results]
251
                final_patient_box_results = pool.map(merge_2D_to_3D_preds_per_patient, mp_inputs, chunksize=1)
252
                pool.close()
253
                pool.join()
254
255
            # final_patient_box_results holds [avg_boxes, pid] if wbc
256
            for ix in range(len(results_per_patient)):
257
                assert results_per_patient[ix][1] == final_patient_box_results[ix][1], "should be same pid"
258
                results_per_patient[ix][0]["boxes"] = final_patient_box_results[ix][0]
259
260
            out_string = 'wbc_pred_boxes_hold_out_list' if self.cf.hold_out_test_set else 'wbc_pred_boxes_list'
261
            with open(os.path.join(self.cf.fold_dir, '{}.pickle'.format(out_string)), 'wb') as handle:
262
               pickle.dump(results_per_patient, handle)
263
264
265
            return results_per_patient
266
267
268
    def load_saved_predictions(self, apply_wbc=False):
269
        """
270
        loads raw predictions saved by self.predict_test_set. consolidates and merges 2D boxes to 3D cubes for evaluation.
271
        (if model predicts 2D but evaluation is run in 3D)
272
        :return: (optionally) results_list: list over patient results. each entry is a dict with keys:
273
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
274
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
275
                            (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
276
                 - 'seg_preds': not implemented yet. todo for evaluation of instance/semantic segmentation.
277
        """
278
279
        # load predictions for a single test-set fold.
280
        results_file = 'raw_pred_boxes_hold_out_list.pickle' if self.cf.hold_out_test_set else 'raw_pred_boxes_list.pickle'
281
        if not self.cf.hold_out_test_set or not self.cf.ensemble_folds:
282
            with open(os.path.join(self.cf.fold_dir, results_file), 'rb') as handle:
283
                results_list = pickle.load(handle)
284
            box_results_list = [(res_dict["boxes"], pid) for res_dict, pid in results_list]
285
            da_factor = 4 if self.cf.test_aug else 1
286
            n_ens = self.cf.test_n_epochs * da_factor
287
            self.logger.info('loaded raw test set predictions with n_patients = {} and n_ens = {}'.format(
288
                len(results_list), n_ens))
289
290
        # if hold out test set was perdicted, aggregate predictions of all trained models
291
        # corresponding to all CV-folds and flatten them.
292
        else:
293
            self.logger.info("loading saved predictions of hold-out test set and ensembling over folds.")
294
            fold_dirs = sorted([os.path.join(self.cf.exp_dir, f) for f in os.listdir(self.cf.exp_dir) if
295
                                os.path.isdir(os.path.join(self.cf.exp_dir, f)) and f.startswith("fold")])
296
297
            results_list = []
298
            folds_loaded = 0
299
            for fold in range(self.cf.n_cv_splits):
300
                fold_dir = os.path.join(self.cf.exp_dir, 'fold_{}'.format(fold))
301
                if fold_dir in fold_dirs:
302
                    with open(os.path.join(fold_dir, results_file), 'rb') as handle:
303
                        fold_list = pickle.load(handle)
304
                        results_list += fold_list
305
                        folds_loaded += 1
306
                else:
307
                    self.logger.info("Skipping fold {} since no saved predictions found.".format(fold))
308
            box_results_list = []
309
            for res_dict, pid in results_list: #without filtering gt out:
310
                box_results_list.append((res_dict['boxes'], pid))
311
312
            da_factor = 4 if self.cf.test_aug else 1
313
            n_ens = self.cf.test_n_epochs * da_factor * folds_loaded
314
315
        # consolidate predictions.
316
        if apply_wbc:
317
            self.logger.info('applying wcs to test set predictions with iou = {} and n_ens = {}.'.format(
318
                self.cf.wcs_iou, n_ens))
319
            pool = Pool(processes=6)
320
            mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.wcs_iou, n_ens] for ii in box_results_list]
321
            box_results_list = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1)
322
            pool.close()
323
            pool.join()
324
325
        # merge 2D box predictions to 3D cubes (if model predicts 2D but evaluation is run in 3D)
326
        if self.cf.merge_2D_to_3D_preds:
327
            self.logger.info(
328
                'applying 2Dto3D merging to test set predictions with iou = {}.'.format(self.cf.merge_3D_iou))
329
            pool = Pool(processes=6)
330
            mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in box_results_list]
331
            box_results_list = pool.map(merge_2D_to_3D_preds_per_patient, mp_inputs, chunksize=1)
332
            pool.close()
333
            pool.join()
334
335
336
        for ix in range(len(results_list)):
337
            assert np.all(
338
                results_list[ix][1] == box_results_list[ix][1]), "pid mismatch between loaded and aggregated results"
339
            results_list[ix][0]["boxes"] = box_results_list[ix][0]
340
341
        return results_list  # holds (results_dict, pid)
342
343
344
    def data_aug_forward(self, batch):
345
        """
346
        in val_mode: passes batch through to spatial_tiling method without data_aug.
347
        in test_mode: if cf.test_aug is set in configs, createst 4 mirrored versions of the input image,
348
        passes all of them to the next processing step (spatial_tiling method) and re-transforms returned predictions
349
        to original image version.
350
        :return. results_dict: stores the results for one patient. dictionary with keys:
351
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
352
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
353
                            and a dummy batch dimension of 1 for 3D predictions.
354
                 - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
355
                 - losses (only in validation mode)
356
        """
357
        patch_crops = batch['patch_crop_coords'] if self.patched_patient else None
358
        results_list = [self.spatial_tiling_forward(batch, patch_crops)]
359
        org_img_shape = batch['original_img_shape']
360
361
        if self.mode == 'test' and self.cf.test_aug:
362
363
            if self.patched_patient:
364
                # apply mirror transformations to patch-crop coordinates, for correct tiling in spatial_tiling method.
365
                mirrored_patch_crops = get_mirrored_patch_crops(patch_crops, batch['original_img_shape'])
366
            else:
367
                mirrored_patch_crops = [None] * 3
368
369
            img = np.copy(batch['data'])
370
371
            # first mirroring: y-axis.
372
            batch['data'] = np.flip(img, axis=2).copy()
373
            chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[0], n_aug='1')
374
            # re-transform coordinates.
375
            for ix in range(len(chunk_dict['boxes'])):
376
                for boxix in range(len(chunk_dict['boxes'][ix])):
377
                    coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy()
378
                    coords[0] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][2]
379
                    coords[2] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][0]
380
                    assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
381
                    assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
382
                    chunk_dict['boxes'][ix][boxix]['box_coords'] = coords
383
            # re-transform segmentation predictions.
384
            chunk_dict['seg_preds'] = np.flip(chunk_dict['seg_preds'], axis=2)
385
            results_list.append(chunk_dict)
386
387
            # second mirroring: x-axis.
388
            batch['data'] = np.flip(img, axis=3).copy()
389
            chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[1], n_aug='2')
390
            # re-transform coordinates.
391
            for ix in range(len(chunk_dict['boxes'])):
392
                for boxix in range(len(chunk_dict['boxes'][ix])):
393
                    coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy()
394
                    coords[1] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][3]
395
                    coords[3] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][1]
396
                    assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
397
                    assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
398
                    chunk_dict['boxes'][ix][boxix]['box_coords'] = coords
399
            # re-transform segmentation predictions.
400
            chunk_dict['seg_preds'] = np.flip(chunk_dict['seg_preds'], axis=3)
401
            results_list.append(chunk_dict)
402
403
            # third mirroring: y-axis and x-axis.
404
            batch['data'] = np.flip(np.flip(img, axis=2), axis=3).copy()
405
            chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[2], n_aug='3')
406
            # re-transform coordinates.
407
            for ix in range(len(chunk_dict['boxes'])):
408
                for boxix in range(len(chunk_dict['boxes'][ix])):
409
                    coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy()
410
                    coords[0] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][2]
411
                    coords[2] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][0]
412
                    coords[1] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][3]
413
                    coords[3] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][1]
414
                    assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
415
                    assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()]
416
                    chunk_dict['boxes'][ix][boxix]['box_coords'] = coords
417
            # re-transform segmentation predictions.
418
            chunk_dict['seg_preds'] = np.flip(np.flip(chunk_dict['seg_preds'], axis=2), axis=3).copy()
419
            results_list.append(chunk_dict)
420
421
            batch['data'] = img
422
423
        # aggregate all boxes/seg_preds per batch element from data_aug predictions.
424
        results_dict = {}
425
        results_dict['boxes'] = [[item for d in results_list for item in d['boxes'][batch_instance]]
426
                                 for batch_instance in range(org_img_shape[0])]
427
        results_dict['seg_preds'] = np.array([[item for d in results_list for item in d['seg_preds'][batch_instance]]
428
                                              for batch_instance in range(org_img_shape[0])])
429
        if self.mode == 'val':
430
            try:
431
                results_dict['torch_loss'] = results_list[0]['torch_loss']
432
                results_dict['class_loss'] = results_list[0]['class_loss']
433
            except KeyError:
434
                pass
435
        return results_dict
436
437
438
    def spatial_tiling_forward(self, batch, patch_crops=None, n_aug='0'):
439
        """
440
        forwards batch to batch_tiling_forward method and receives and returns a dictionary with results.
441
        if patch-based prediction, the results received from batch_tiling_forward will be on a per-patch-basis.
442
        this method uses the provided patch_crops to re-transform all predictions to whole-image coordinates.
443
        Patch-origin information of all box-predictions will be needed for consolidation, hence it is stored as
444
        'patch_id', which is a unique string for each patch (also takes current data aug and temporal epoch instances
445
        into account). all box predictions get additional information about the amount overlapping patches at the
446
        respective position (used for consolidation).
447
        :return. results_dict: stores the results for one patient. dictionary with keys:
448
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
449
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
450
                            and a dummy batch dimension of 1 for 3D predictions.
451
                 - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
452
                 - losses (only in validation mode)
453
        """
454
        if patch_crops is not None:
455
456
            patches_dict = self.batch_tiling_forward(batch)
457
458
            results_dict = {'boxes': [[] for _ in range(batch['original_img_shape'][0])]}
459
460
            # instanciate segemntation output array. Will contain averages over patch predictions.
461
            out_seg_preds = np.zeros(batch['original_img_shape'], dtype=np.float16)[:, 0][:, None]
462
            # counts patch instances per pixel-position.
463
            patch_overlap_map = np.zeros_like(out_seg_preds, dtype='uint8')
464
465
            #unmold segmentation outputs. loop over patches.
466
            for pix, pc in enumerate(patch_crops):
467
                if self.cf.dim == 3:
468
                    out_seg_preds[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += patches_dict['seg_preds'][pix][None]
469
                    patch_overlap_map[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += 1
470
                else:
471
                    out_seg_preds[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += patches_dict['seg_preds'][pix]
472
                    patch_overlap_map[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += 1
473
474
            # take mean in overlapping areas.
475
            out_seg_preds[patch_overlap_map > 0] /= patch_overlap_map[patch_overlap_map > 0]
476
            results_dict['seg_preds'] = out_seg_preds
477
478
            # unmold box outputs. loop over patches.
479
            for pix, pc in enumerate(patch_crops):
480
                patch_boxes = patches_dict['boxes'][pix]
481
482
                for box in patch_boxes:
483
484
                    # add unique patch id for consolidation of predictions.
485
                    box['patch_id'] = self.rank_ix + '_' + n_aug + '_' + str(pix)
486
487
                    # boxes from the edges of a patch have a lower prediction quality, than the ones at patch-centers.
488
                    # hence they will be downweighted for consolidation, using the 'box_patch_center_factor', which is
489
                    # obtained by a normal distribution over positions in the patch and average over spatial dimensions.
490
                    # Also the info 'box_n_overlaps' is stored for consolidation, which depicts the amount over
491
                    # overlapping patches at the box's position.
492
                    c = box['box_coords']
493
                    box_centers = [(c[ii] + c[ii + 2]) / 2 for ii in range(2)]
494
                    if self.cf.dim == 3:
495
                        box_centers.append((c[4] + c[5]) / 2)
496
                    box['box_patch_center_factor'] = np.mean(
497
                        [norm.pdf(bc, loc=pc, scale=pc * 0.8) * np.sqrt(2 * np.pi) * pc * 0.8 for bc, pc in
498
                         zip(box_centers, np.array(self.cf.patch_size) / 2)])
499
                    if self.cf.dim == 3:
500
                        c += np.array([pc[0], pc[2], pc[0], pc[2], pc[4], pc[4]])
501
                        int_c = [int(np.floor(ii)) if ix%2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)]
502
                        box['box_n_overlaps'] = np.mean(patch_overlap_map[:, :, int_c[1]:int_c[3], int_c[0]:int_c[2], int_c[4]:int_c[5]])
503
                        results_dict['boxes'][0].append(box)
504
                    else:
505
                        c += np.array([pc[0], pc[2], pc[0], pc[2]])
506
                        int_c = [int(np.floor(ii)) if ix % 2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)]
507
                        box['box_n_overlaps'] = np.mean(patch_overlap_map[pc[4], :, int_c[1]:int_c[3], int_c[0]:int_c[2]])
508
                        results_dict['boxes'][pc[4]].append(box)
509
510
            if self.mode == 'val':
511
                try:
512
                    results_dict['torch_loss'] = patches_dict['torch_loss']
513
                    results_dict['class_loss'] = patches_dict['class_loss']
514
                except KeyError:
515
                    pass
516
        # if predictions are not patch-based:
517
        # add patch-origin info to boxes (entire image is the same patch with overlap=1) and return results.
518
        else:
519
            results_dict = self.batch_tiling_forward(batch)
520
            for b in results_dict['boxes']:
521
                for box in b:
522
                    box['box_patch_center_factor'] = 1
523
                    box['box_n_overlaps'] = 1
524
                    box['patch_id'] = self.rank_ix + '_' + n_aug
525
526
        return results_dict
527
528
529
    def batch_tiling_forward(self, batch):
530
        """
531
        calls the actual network forward method. in patch-based prediction, the batch dimension might be overladed
532
        with n_patches >> batch_size, which would exceed gpu memory. In this case, batches are processed in chunks of
533
        batch_size. validation mode calls the train method to monitor losses (returned ground truth objects are discarded).
534
        test mode calls the test forward method, no ground truth required / involved.
535
        :return. results_dict: stores the results for one patient. dictionary with keys:
536
                 - 'boxes': list over batch elements. each element is a list over boxes, where each box is
537
                            one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
538
                            and a dummy batch dimension of 1 for 3D predictions.
539
                 - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
540
                 - losses (only in validation mode)
541
        """
542
        #self.logger.info('forwarding (patched) patient with shape: {}'.format(batch['data'].shape))
543
544
        img = batch['data']
545
546
        #batch['data'] = torch.from_numpy(batch['data']).float().to(self.device)
547
548
        if img.shape[0] <= self.cf.batch_size:
549
550
            if self.mode == 'val':
551
                # call training method to monitor losses
552
                results_dict = self.net.train_forward(batch, is_validation=True)
553
                # discard returned ground-truth boxes (also training info boxes).
554
                results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']]
555
            else:
556
                results_dict = self.net.test_forward(batch, return_masks=self.cf.return_masks_in_test)
557
558
        else:
559
            split_ixs = np.split(np.arange(img.shape[0]), np.arange(img.shape[0])[::self.cf.batch_size])
560
            chunk_dicts = []
561
            for chunk_ixs in split_ixs[1:]:  # first split is elements before 0, so empty
562
                b = {k: batch[k][chunk_ixs] for k in batch.keys()
563
                     if (isinstance(batch[k], np.ndarray) and batch[k].shape[0] == img.shape[0])}
564
                if self.mode == 'val':
565
                    chunk_dicts += [self.net.train_forward(b, is_validation=True)]
566
                else:
567
                    chunk_dicts += [self.net.test_forward(b, return_masks=self.cf.return_masks_in_test)]
568
569
570
            results_dict = {}
571
            # flatten out batch elements from chunks ([chunk, chunk] -> [b, b, b, b, ...])
572
            results_dict['boxes'] = [item for d in chunk_dicts for item in d['boxes']]
573
            results_dict['seg_preds'] = np.array([item for d in chunk_dicts for item in d['seg_preds']])
574
575
            if self.mode == 'val':
576
                try:
577
                    # estimate metrics by mean over batch_chunks. Most similar to training metrics.
578
                    results_dict['torch_loss'] = torch.mean(torch.cat([d['torch_loss'] for d in chunk_dicts]))
579
                    results_dict['class_loss'] = np.mean([d['class_loss'] for d in chunk_dicts])
580
                except KeyError:
581
                    # losses are not necessarily monitored
582
                    pass
583
                # discard returned ground-truth boxes (also training info boxes).
584
                results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']]
585
586
        return results_dict
587
588
589
590
def apply_wbc_to_patient(inputs):
591
    """
592
    wrapper around prediction box consolidation: weighted cluster scoring (wcs). processes a single patient.
593
    loops over batch elements in patient results (1 in 3D, slices in 2D) and foreground classes,
594
    aggregates and stores results in new list.
595
    :return. patient_results_list: list over batch elements. each element is a list over boxes, where each box is
596
                                 one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D
597
                                 predictions, and a dummy batch dimension of 1 for 3D predictions.
598
    :return. pid: string. patient id.
599
    """
600
    in_patient_results_list, pid, class_dict, wcs_iou, n_ens = inputs
601
    out_patient_results_list = [[] for _ in range(len(in_patient_results_list))]
602
603
    for bix, b in enumerate(in_patient_results_list):
604
605
        for cl in list(class_dict.keys()):
606
607
            boxes = [(ix, box) for ix, box in enumerate(b) if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)]
608
            box_coords = np.array([b[1]['box_coords'] for b in boxes])
609
            box_scores = np.array([b[1]['box_score'] for b in boxes])
610
            box_center_factor = np.array([b[1]['box_patch_center_factor'] for b in boxes])
611
            box_n_overlaps = np.array([b[1]['box_n_overlaps'] for b in boxes])
612
            box_patch_id = np.array([b[1]['patch_id'] for b in boxes])
613
614
            if 0 not in box_scores.shape:
615
                keep_scores, keep_coords = weighted_box_clustering(
616
                    np.concatenate((box_coords, box_scores[:, None], box_center_factor[:, None],
617
                                    box_n_overlaps[:, None]), axis=1), box_patch_id, wcs_iou, n_ens)
618
619
                for boxix in range(len(keep_scores)):
620
                    out_patient_results_list[bix].append({'box_type': 'det', 'box_coords': keep_coords[boxix],
621
                                             'box_score': keep_scores[boxix], 'box_pred_class_id': cl})
622
623
        # add gt boxes back to new output list.
624
        out_patient_results_list[bix].extend([box for box in b if box['box_type'] == 'gt'])
625
626
    return [out_patient_results_list, pid]
627
628
629
630
def merge_2D_to_3D_preds_per_patient(inputs):
631
    """
632
    wrapper around 2Dto3D merging operation. Processes a single patient. Takes 2D patient results (slices in batch dimension)
633
    and returns 3D patient results (dummy batch dimension of 1). Applies an adaption of Non-Maximum Surpression
634
    (Detailed methodology is described in nms_2to3D).
635
    :return. results_dict_boxes: list over batch elements (1 in 3D). each element is a list over boxes, where each box is
636
                                 one dictionary: [[box_0, ...], [box_n,...]].
637
    :return. pid: string. patient id.
638
    """
639
    in_patient_results_list, pid, class_dict, merge_3D_iou = inputs
640
    out_patient_results_list = []
641
642
    for cl in list(class_dict.keys()):
643
        boxes, slice_ids = [], []
644
        # collect box predictions over batch dimension (slices) and store slice info as slice_ids.
645
        for bix, b in enumerate(in_patient_results_list):
646
            det_boxes = [(ix, box) for ix, box in enumerate(b) if
647
                     (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)]
648
            boxes += det_boxes
649
            slice_ids += [bix] * len(det_boxes)
650
651
        box_coords = np.array([b[1]['box_coords'] for b in boxes])
652
        box_scores = np.array([b[1]['box_score'] for b in boxes])
653
        slice_ids = np.array(slice_ids)
654
655
        if 0 not in box_scores.shape:
656
            keep_ix, keep_z = nms_2to3D(
657
                np.concatenate((box_coords, box_scores[:, None], slice_ids[:, None]), axis=1), merge_3D_iou)
658
        else:
659
            keep_ix, keep_z = [], []
660
661
        # store kept predictions in new results list and add corresponding z-dimension info to coordinates.
662
        for kix, kz in zip(keep_ix, keep_z):
663
            out_patient_results_list.append({'box_type': 'det', 'box_coords': list(box_coords[kix]) + kz,
664
                                             'box_score': box_scores[kix], 'box_pred_class_id': cl})
665
666
    gt_boxes = [box for b in in_patient_results_list for box in b if box['box_type'] == 'gt']
667
    if len(gt_boxes) > 0:
668
        assert np.all([len(box["box_coords"]) == 6 for box in gt_boxes]), "expanded preds to 3D but GT is 2D."
669
    out_patient_results_list += gt_boxes
670
671
    # add dummy batch dimension 1 for 3D.
672
    return [[out_patient_results_list], pid]
673
674
675
676
def weighted_box_clustering(dets, box_patch_id, thresh, n_ens):
677
    """
678
    consolidates overlapping predictions resulting from patch overlaps, test data augmentations and temporal ensembling.
679
    clusters predictions together with iou > thresh (like in NMS). Output score and coordinate for one cluster are the
680
    average weighted by individual patch center factors (how trustworthy is this candidate measured by how centered
681
    its position the patch is) and the size of the corresponding box.
682
    The number of expected predictions at a position is n_data_aug * n_temp_ens * n_overlaps_at_position
683
    (1 prediction per unique patch). Missing predictions at a cluster position are defined as the number of unique
684
    patches in the cluster, which did not contribute any predict any boxes.
685
    :param dets: (n_dets, (y1, x1, y2, x2, (z1), (z2), scores, box_pc_facts, box_n_ovs)
686
    :param thresh: threshold for iou_matching.
687
    :param n_ens: number of models, that are ensembled. (-> number of expected predicitions per position)
688
    :return: keep_scores: (n_keep)  new scores of boxes to be kept.
689
    :return: keep_coords: (n_keep, (y1, x1, y2, x2, (z1), (z2)) new coordinates of boxes to be kept.
690
    """
691
    dim = 2 if dets.shape[1] == 7 else 3
692
    y1 = dets[:, 0]
693
    x1 = dets[:, 1]
694
    y2 = dets[:, 2]
695
    x2 = dets[:, 3]
696
    scores = dets[:, -3]
697
    box_pc_facts = dets[:, -2]
698
    box_n_ovs = dets[:, -1]
699
700
    areas = (y2 - y1 + 1) * (x2 - x1 + 1)
701
702
    if dim == 3:
703
        z1 = dets[:, 4]
704
        z2 = dets[:, 5]
705
        areas *= (z2 - z1 + 1)
706
707
    # order is the sorted index.  maps order to index o[1] = 24 (rank1, ix 24)
708
    order = scores.argsort()[::-1]
709
710
    keep = []
711
    keep_scores = []
712
    keep_coords = []
713
714
    while order.size > 0:
715
        i = order[0]  # higehst scoring element
716
        xx1 = np.maximum(x1[i], x1[order])
717
        yy1 = np.maximum(y1[i], y1[order])
718
        xx2 = np.minimum(x2[i], x2[order])
719
        yy2 = np.minimum(y2[i], y2[order])
720
721
        w = np.maximum(0.0, xx2 - xx1 + 1)
722
        h = np.maximum(0.0, yy2 - yy1 + 1)
723
        inter = w * h
724
725
        if dim == 3:
726
            zz1 = np.maximum(z1[i], z1[order])
727
            zz2 = np.minimum(z2[i], z2[order])
728
            d = np.maximum(0.0, zz2 - zz1 + 1)
729
            inter *= d
730
731
        # overall between currently highest scoring box and all boxes.
732
        ovr = inter / (areas[i] + areas[order] - inter)
733
734
        # get all the predictions that match the current box to build one cluster.
735
        matches = np.argwhere(ovr > thresh)
736
737
        match_n_ovs = box_n_ovs[order[matches]]
738
        match_pc_facts = box_pc_facts[order[matches]]
739
        match_patch_id = box_patch_id[order[matches]]
740
        match_ov_facts = ovr[matches]
741
        match_areas = areas[order[matches]]
742
        match_scores = scores[order[matches]]
743
744
        # weight all socres in cluster by patch factors, and size.
745
        match_score_weights = match_ov_facts * match_areas * match_pc_facts
746
        match_scores *= match_score_weights
747
748
        # for the weigted average, scores have to be divided by the number of total expected preds at the position
749
        # of the current cluster. 1 Prediction per patch is expected. therefore, the number of ensembled models is
750
        # multiplied by the mean overlaps of  patches at this position (boxes of the cluster might partly be
751
        # in areas of different overlaps).
752
        n_expected_preds = n_ens * np.mean(match_n_ovs)
753
754
        # the number of missing predictions is obtained as the number of patches,
755
        # which did not contribute any prediction to the current cluster.
756
        n_missing_preds = np.max((0, n_expected_preds - np.unique(match_patch_id).shape[0]))
757
758
        # missing preds are given the mean weighting
759
        # (expected prediction is the mean over all predictions in cluster).
760
        denom = np.sum(match_score_weights) + n_missing_preds * np.mean(match_score_weights)
761
762
        # compute weighted average score for the cluster
763
        avg_score = np.sum(match_scores) / denom
764
765
        # compute weighted average of coordinates for the cluster. now only take existing
766
        # predictions into account.
767
        avg_coords = [np.sum(y1[order[matches]] * match_scores) / np.sum(match_scores),
768
                      np.sum(x1[order[matches]] * match_scores) / np.sum(match_scores),
769
                      np.sum(y2[order[matches]] * match_scores) / np.sum(match_scores),
770
                      np.sum(x2[order[matches]] * match_scores) / np.sum(match_scores)]
771
        if dim == 3:
772
            avg_coords.append(np.sum(z1[order[matches]] * match_scores) / np.sum(match_scores))
773
            avg_coords.append(np.sum(z2[order[matches]] * match_scores) / np.sum(match_scores))
774
775
        # some clusters might have very low scores due to high amounts of missing predictions.
776
        # filter out the with a conservative threshold, to speed up evaluation.
777
        if avg_score > 0.01:
778
            keep_scores.append(avg_score)
779
            keep_coords.append(avg_coords)
780
781
        # get index of all elements that were not matched and discard all others.
782
        inds = np.where(ovr <= thresh)[0]
783
        order = order[inds]
784
785
    return keep_scores, keep_coords
786
787
788
789
def nms_2to3D(dets, thresh):
790
    """
791
    Merges 2D boxes to 3D cubes. Therefore, boxes of all slices are projected into one slices. An adaptation of Non-maximum surpression
792
    is applied, where clusters are found (like in NMS) with an extra constrained, that surpressed boxes have to have 'connected'
793
    z-coordinates w.r.t the core slice (cluster center, highest scoring box). 'connected' z-coordinates are determined
794
    as the z-coordinates with predictions until the first coordinate, where no prediction was found.
795
796
    example: a cluster of predictions was found overlap > iou thresh in xy (like NMS). The z-coordinate of the highest
797
    scoring box is 50. Other predictions have 23, 46, 48, 49, 51, 52, 53, 56, 57.
798
    Only the coordinates connected with 50 are clustered to one cube: 48, 49, 51, 52, 53. (46 not because nothing was
799
    found in 47, so 47 is a 'hole', which interrupts the connection). Only the boxes corresponding to these coordinates
800
    are surpressed. All others are kept for building of further clusters.
801
802
    This algorithm works better with a certain min_confidence of predictions, because low confidence (e.g. noisy/cluttery)
803
    predictions can break the relatively strong assumption of defining cubes' z-boundaries at the first 'hole' in the cluster.
804
805
    :param dets: (n_detections, (y1, x1, y2, x2, scores, slice_id)
806
    :param thresh: iou matchin threshold (like in NMS).
807
    :return: keep: (n_keep) 1D tensor of indices to be kept.
808
    :return: keep_z: (n_keep, [z1, z2]) z-coordinates to be added to boxes, which are kept in order to form cubes.
809
    """
810
    y1 = dets[:, 0]
811
    x1 = dets[:, 1]
812
    y2 = dets[:, 2]
813
    x2 = dets[:, 3]
814
    scores = dets[:, -2]
815
    slice_id = dets[:, -1]
816
817
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
818
    order = scores.argsort()[::-1]
819
820
    keep = []
821
    keep_z = []
822
823
    while order.size > 0:  # order is the sorted index.  maps order to index o[1] = 24 (rank1, ix 24)
824
        i = order[0]  # pop higehst scoring element
825
        xx1 = np.maximum(x1[i], x1[order])
826
        yy1 = np.maximum(y1[i], y1[order])
827
        xx2 = np.minimum(x2[i], x2[order])
828
        yy2 = np.minimum(y2[i], y2[order])
829
830
        w = np.maximum(0.0, xx2 - xx1 + 1)
831
        h = np.maximum(0.0, yy2 - yy1 + 1)
832
        inter = w * h
833
834
        ovr = inter / (areas[i] + areas[order] - inter)
835
        matches = np.argwhere(ovr > thresh)  # get all the elements that match the current box and have a lower score
836
837
        slice_ids = slice_id[order[matches]]
838
        core_slice = slice_id[int(i)]
839
        upper_wholes = [ii for ii in np.arange(core_slice, np.max(slice_ids)) if ii not in slice_ids]
840
        lower_wholes = [ii for ii in np.arange(np.min(slice_ids), core_slice) if ii not in slice_ids]
841
        max_valid_slice_id = np.min(upper_wholes) if len(upper_wholes) > 0 else np.max(slice_ids)
842
        min_valid_slice_id = np.max(lower_wholes) if len(lower_wholes) > 0 else np.min(slice_ids)
843
        z_matches = matches[(slice_ids <= max_valid_slice_id) & (slice_ids >= min_valid_slice_id)]
844
845
        z1 = np.min(slice_id[order[z_matches]]) - 1
846
        z2 = np.max(slice_id[order[z_matches]]) + 1
847
848
        keep.append(i)
849
        keep_z.append([z1, z2])
850
        order = np.delete(order, z_matches, axis=0)
851
852
    return keep, keep_z
853
854
855
856
def get_mirrored_patch_crops(patch_crops, org_img_shape):
857
    """
858
    apply 3 mirrror transformations (x-axis, y-axis, x&y-axis)
859
    to given patch crop coordinates and return the transformed coordinates.
860
    Handles 2D and 3D coordinates.
861
    :param patch_crops: list of crops: each element is a list of coordinates for given crop [[y1, x1, ...], [y1, x1, ..]]
862
    :param org_img_shape: shape of patient volume used as world coordinates.
863
    :return: list of mirrored patch crops: lenght=3. each element is a list of transformed patch crops.
864
    """
865
    mirrored_patch_crops = []
866
867
    # y-axis transform.
868
    mirrored_patch_crops.append([[org_img_shape[2] - ii[1],
869
                                  org_img_shape[2] - ii[0],
870
                                  ii[2], ii[3]] if len(ii) == 4 else
871
                                 [org_img_shape[2] - ii[1],
872
                                  org_img_shape[2] - ii[0],
873
                                  ii[2], ii[3], ii[4], ii[5]] for ii in patch_crops])
874
875
    # x-axis transform.
876
    mirrored_patch_crops.append([[ii[0], ii[1],
877
                                  org_img_shape[3] - ii[3],
878
                                  org_img_shape[3] - ii[2]] if len(ii) == 4 else
879
                                 [ii[0], ii[1],
880
                                  org_img_shape[3] - ii[3],
881
                                  org_img_shape[3] - ii[2],
882
                                  ii[4], ii[5]] for ii in patch_crops])
883
884
    # y-axis and x-axis transform.
885
    mirrored_patch_crops.append([[org_img_shape[2] - ii[1],
886
                                  org_img_shape[2] - ii[0],
887
                                  org_img_shape[3] - ii[3],
888
                                  org_img_shape[3] - ii[2]] if len(ii) == 4 else
889
                                 [org_img_shape[2] - ii[1],
890
                                  org_img_shape[2] - ii[0],
891
                                  org_img_shape[3] - ii[3],
892
                                  org_img_shape[3] - ii[2],
893
                                  ii[4], ii[5]] for ii in patch_crops])
894
895
    return mirrored_patch_crops
896
897
898