# inference.py
"""
Inference module supporting whole slide images using TF Serving.
Eric Wang
Email: eric.wang@thorough.ai
"""
import os
import shutil
import numpy as np
from utils.tf_serving import TFServing
from utils.slide import Slide
from utils import config
from utils.libs import write, generate_effective_regions, generate_overlap_tile, \
    post_processing, concat_patches
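
# Note: the config module is expected to provide MODEL_NAME, TF_SERVING_HOST,
# TF_SERVING_PORT, TEMP_DIR, PATCH_SIZE, CENTER_SIZE, DO_POST_PROCESSING,
# KEEP_TEMP, and format_mapping, all of which are referenced below.
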
class Inference:
    def __init__(self, data_dir, data_list, class_num, result_dir, use_level):
        """
        The main inference module, wrapped in a class so it is easy to call.
        :param data_dir: The directory storing the whole slide images.
        :param data_list: The text file listing the slide names.
        :param class_num: Number of predicted classes.
        :param result_dir: Where to put the predicted results.
        :param use_level: Which slide level to analyze: 0 for 40x, 1 for 20x, etc.
        """
        if data_dir.endswith('/'):
            self.data_dir = data_dir
        else:
            self.data_dir = data_dir + '/'
        self.data_list = data_list
        self.class_num = class_num
        if result_dir.endswith('/'):
            self.result_dir = result_dir
        else:
            self.result_dir = result_dir + '/'
        if not os.path.exists(self.result_dir):
            os.makedirs(self.result_dir)
        self.use_level = use_level
        self.config = config

    @staticmethod
    def _infer(tfs_client, image):
        """
        Runs inference on an image patch using TF Serving.
        :param tfs_client: TF Serving client.
        :param image: The image patch.
        :return: Predicted heatmap.
        """
        try:
            prediction = tfs_client.predict(image, config.MODEL_NAME)
        except Exception as e:
            # Log the serving host before re-raising to aid debugging.
            print('TF_SERVING_HOST: {}'.format(config.TF_SERVING_HOST))
            print(e)
            raise
        else:
            return prediction

    def run(self):
        """
        Runs the inference procedure.
        """
        with open(self.data_list) as f:
            inference_list = f.readlines()
        tfs_client = TFServing(config.TF_SERVING_HOST, config.TF_SERVING_PORT)
        for item in inference_list:
            slide_path = item.strip()
            # rsplit keeps filenames containing extra dots intact.
            image_name, image_suffix = slide_path.split('/')[-1].rsplit('.', 1)
            print('[INFO] Analyzing: ' + self.data_dir + slide_path)
            if image_suffix not in self.config.format_mapping:
                print('[ERROR] File ' + slide_path + ' format not supported yet.')
                continue
            image_handle = Slide(self.data_dir + slide_path)
            image_dimensions = image_handle.level_dimensions[self.use_level]
            regions = generate_effective_regions(image_dimensions)
            index = 0
            region_num = len(regions)
            temp_dir = self.config.TEMP_DIR + image_name + '/'
            if not os.path.exists(temp_dir):
                os.makedirs(temp_dir)
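            # Overlap-tile strategy (intent inferred from generate_overlap_tile
            # and the CENTER_SIZE clipping below): each tile is shifted so the
            # model sees surrounding context, and the prediction is clipped
            # back to the tile's effective center before stitching.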
            for region in regions:
                shifted_region, clip_region = generate_overlap_tile(region, image_dimensions)
                index += 1
                print('[INFO] Progress: ' + str(index) + ' / ' + str(region_num))
                input_image = np.array(image_handle.read_region(
                    location=(int(shifted_region[0]),
                              int(shifted_region[1])),
                    level=self.use_level,
                    size=(self.config.PATCH_SIZE, self.config.PATCH_SIZE)))[:, :, 0:3]
                prediction_result = self._infer(tfs_client, input_image)
                # Clip back to the tile's effective center, then to the
                # region's own bounds.
                prediction_result = prediction_result[
                    clip_region[0]:(self.config.CENTER_SIZE + clip_region[0]),
                    clip_region[1]:(self.config.CENTER_SIZE + clip_region[1])]
                prediction_result = prediction_result[region[2]:(region[4] + 1), region[3]:(region[5] + 1)]
                if self.config.DO_POST_PROCESSING:
                    prediction_result = post_processing(prediction_result)
                write(temp_dir + image_name + '_' + str(region[0]) + '_' + str(region[1])
                      + '_prediction.png', prediction_result, self.class_num)
            print('[INFO] Stitching patches...')
            full_prediction = concat_patches(temp_dir, image_name)
            result_path = self.result_dir + '_'.join([image_name, 'prediction_thumbnail']) + '.png'
            write(result_path, full_prediction, color_map=False)
            if not self.config.KEEP_TEMP:
                shutil.rmtree(temp_dir)
            print('[INFO] Prediction saved to ' + result_path)
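

# A minimal usage sketch. The paths and parameter values below are
# hypothetical examples, not part of the original module.
if __name__ == '__main__':
    inference = Inference(data_dir='/data/slides',
                          data_list='/data/slide_list.txt',
                          class_num=2,
                          result_dir='/data/results',
                          use_level=1)
    inference.run()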