|
a |
|
b/inference.py |
|
|
1 |
""" |
|
|
2 |
Inference module supporting whole slide images using TF Serving. |
|
|
3 |
Eric Wang |
|
|
4 |
Email: eric.wang@thorough.ai |
|
|
5 |
""" |
|
|
6 |
import os |
|
|
7 |
import shutil |
|
|
8 |
import numpy as np |
|
|
9 |
from utils.tf_serving import TFServing |
|
|
10 |
from utils.slide import Slide |
|
|
11 |
from utils import config |
|
|
12 |
from utils.libs import write, generate_effective_regions, generate_overlap_tile, \ |
|
|
13 |
post_processing, concat_patches |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
class Inference: |
|
|
17 |
def __init__(self, data_dir, data_list, class_num, result_dir, use_level): |
|
|
18 |
""" |
|
|
19 |
This is the main inference module for the sake of easy to call. |
|
|
20 |
:param data_dir: The directory storing the while image slides. |
|
|
21 |
:param data_list: The text file indicating the slide names. |
|
|
22 |
:param class_num: Number of predicted classes. |
|
|
23 |
:param result_dir: Where to put the predicted results. |
|
|
24 |
:param use_level: Which slide size we want to analyze, 0 for 40x, 1 for 20x, etc. |
|
|
25 |
""" |
|
|
26 |
if data_dir.endswith('/'): |
|
|
27 |
self.data_dir = data_dir |
|
|
28 |
else: |
|
|
29 |
self.data_dir = data_dir + '/' |
|
|
30 |
self.data_list = data_list |
|
|
31 |
self.class_num = class_num |
|
|
32 |
if result_dir.endswith('/'): |
|
|
33 |
self.result_dir = result_dir |
|
|
34 |
else: |
|
|
35 |
self.result_dir = result_dir + '/' |
|
|
36 |
if not os.path.exists(self.result_dir): |
|
|
37 |
os.mkdir(self.result_dir) |
|
|
38 |
self.use_level = use_level |
|
|
39 |
self.config = config |
|
|
40 |
|
|
|
41 |
@staticmethod |
|
|
42 |
def _infer(tfs_client, image): |
|
|
43 |
""" |
|
|
44 |
Inference for an image patch using TF Serving. |
|
|
45 |
:param tfs_client: TF Serving client. |
|
|
46 |
:param image: The image patch. |
|
|
47 |
:return: Predicted heatmap. |
|
|
48 |
""" |
|
|
49 |
try: |
|
|
50 |
prediction = tfs_client.predict(image, config.MODEL_NAME) |
|
|
51 |
except Exception as e: |
|
|
52 |
print('TF_SERVING_HOST: {}'.format(config.TF_SERVING_HOST)) |
|
|
53 |
print(e) |
|
|
54 |
raise |
|
|
55 |
else: |
|
|
56 |
return prediction |
|
|
57 |
|
|
|
58 |
def run(self): |
|
|
59 |
""" |
|
|
60 |
Proceeds the inference procedure. |
|
|
61 |
""" |
|
|
62 |
inference_list = open(self.data_list).readlines() |
|
|
63 |
tfs_client = TFServing(config.TF_SERVING_HOST, config.TF_SERVING_PORT) |
|
|
64 |
for item in inference_list: |
|
|
65 |
image_name, image_suffix = item.split('\n')[0].split('/')[-1].split('.') |
|
|
66 |
print('[INFO] Analyzing: ' + self.data_dir + item.split('\n')[0]) |
|
|
67 |
if not image_suffix in self.config.format_mapping.keys(): |
|
|
68 |
print('[ERROR] File ' + item + ' format not supported yet.') |
|
|
69 |
continue |
|
|
70 |
image_handle = Slide(self.data_dir + item.split('\n')[0]) |
|
|
71 |
image_dimensions = image_handle.level_dimensions[self.use_level] |
|
|
72 |
regions = generate_effective_regions(image_dimensions) |
|
|
73 |
index = 0 |
|
|
74 |
region_num = len(regions) |
|
|
75 |
temp_dir = self.config.TEMP_DIR + image_name + '/' |
|
|
76 |
if not os.path.exists(temp_dir): |
|
|
77 |
os.makedirs(temp_dir) |
|
|
78 |
for region in regions: |
|
|
79 |
shifted_region, clip_region = generate_overlap_tile(region, image_dimensions) |
|
|
80 |
index += 1 |
|
|
81 |
if index % 1 == 0: |
|
|
82 |
print('[INFO] Progress: ' + str(index) + ' / ' + str(region_num)) |
|
|
83 |
input_image = np.array(image_handle.read_region( |
|
|
84 |
location=(int(shifted_region[0]), |
|
|
85 |
int(shifted_region[1])), |
|
|
86 |
level=self.use_level, size=(self.config.PATCH_SIZE, self.config.PATCH_SIZE)))[:, :, 0: 3] |
|
|
87 |
prediction_result = self._infer(tfs_client, input_image) |
|
|
88 |
prediction_result = prediction_result[clip_region[0]: (self.config.CENTER_SIZE + clip_region[0]), |
|
|
89 |
clip_region[1]: (self.config.CENTER_SIZE + clip_region[1])] |
|
|
90 |
prediction_result = prediction_result[region[2]:(region[4] + 1), region[3]:(region[5] + 1)] |
|
|
91 |
if self.config.DO_POST_PROCESSING: |
|
|
92 |
prediction_result = post_processing(prediction_result) |
|
|
93 |
write(temp_dir + image_name + '_' + str(region[0]) + '_' + str(region[1]) |
|
|
94 |
+ '_prediction.png', prediction_result, self.class_num) |
|
|
95 |
print('[INFO] Postprocessing...') |
|
|
96 |
full_prediction = concat_patches(temp_dir, image_name) |
|
|
97 |
write(self.result_dir + |
|
|
98 |
'_'.join([image_name, 'prediction_thumbnail']) + '.png', full_prediction, color_map=False) |
|
|
99 |
if not self.config.KEEP_TEMP: |
|
|
100 |
shutil.rmtree(temp_dir) |
|
|
101 |
print('[INFO] Prediction saved to ' + self.result_dir + '_'.join( |
|
|
102 |
[image_name, 'prediction_thumbnail']) + '.png') |