inference_segmentation.py
# Mask R-CNN model for lesion segmentation in chest CT scans
# Torchvision detection package is locally re-implemented
# by Alex Ter-Sarkisov@City, University of London
# alex.ter-sarkisov@city.ac.uk
# 2020
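#
# Inference script: builds the Mask R-CNN model from a saved checkpoint, runs it
# on every image in the test directory and saves a figure with the predicted
# boxes, class labels, confidence scores and mask overlays.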
import argparse
import time
import pickle
import copy
import torch
import torchvision
import numpy as np
import os
import cv2
import models.mask_net as mask_net
from models.mask_net.rpn_segmentation import AnchorGenerator
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor
from torchvision import transforms
from torch.utils import data
import datasets.dataset_segmentation as dataset
from PIL import Image as PILImage
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import utils
import config_segmentation as config


# main method
def main(config, step):
    devices = ['cpu', 'cuda']
    mask_classes = ['both', 'ggo', 'merge']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']
    assert config.device in devices
    assert config.backbone_name in backbones
    assert config.truncation in truncation_levels
    assert config.mask_type in mask_classes
    if config.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # get the configuration: thresholds, directories and model settings
    confidence_threshold, mask_threshold, save_dir, data_dir, img_dir, gt_dir, mask_type, rpn_nms, roi_nms, backbone_name, truncation \
        = config.confidence_th, config.mask_logits_th, config.save_dir, config.test_data_dir, config.test_imgs_dir, \
        config.gt_dir, config.mask_type, config.rpn_nms_th, config.roi_nms_th, config.backbone_name, config.truncation

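    # number of classes in the model heads: background + GGO + CL when both mask
    # types are used, background + a single lesion class otherwise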
    if mask_type == "both":
        n_c = 3
    else:
        n_c = 2
    ckpt = torch.load(config.ckpt, map_location=device)
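
    # The checkpoint is assumed to be a dict with 'model_weights' (the state dict)
    # and 'anchor_generator' entries, and optionally 'model_name'; the anchor
    # parameters are re-used so the RPN matches the trained model.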
    model_name = None
    if 'model_name' in ckpt.keys():
        model_name = ckpt['model_name']
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    print("Anchors: ", anchor_generator.sizes, anchor_generator.aspect_ratios)
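    # In the torchvision convention, sizes and aspect_ratios are tuples with one
    # entry per feature map (e.g. sizes=((32,), (64,), (128,), (256,), (512,))
    # for a 5-level FPN); the locally re-implemented AnchorGenerator is assumed
    # to follow the same format.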

    # create modules
    # this assumes an FPN with 256 output channels
    box_head = TwoMLPHead(in_channels=7 * 7 * 256, representation_size=128)
    if backbone_name == 'resnet50':
        maskrcnn_heads = None
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c)
    else:
        # Backbone -> FPN -> box head -> box predictor
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        maskrcnn_heads = MaskRCNNHeads(in_channels=256, layers=(128,), dilation=1)
        mask_predictor = MaskRCNNPredictor(in_channels=128, dim_reduced=128, num_classes=n_c)

    # keyword arguments
    maskrcnn_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024, 'box_detections_per_img': 100,
                     'box_nms_thresh': roi_nms, 'box_score_thresh': confidence_threshold, 'rpn_nms_thresh': rpn_nms,
                     'box_head': box_head, 'rpn_anchor_generator': anchor_generator, 'mask_head': maskrcnn_heads,
                     'mask_predictor': mask_predictor, 'box_predictor': box_predictor}
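    # 'num_classes' stays None because custom box and mask predictors are passed
    # in explicitly; a torchvision-style constructor raises an error when both a
    # class count and a predictor are given.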

    # Instantiate the segmentation model
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation, pretrained_backbone=False, **maskrcnn_args)
    # Load weights
    maskrcnn_model.load_state_dict(ckpt['model_weights'])
    print(maskrcnn_model)
    # Set to evaluation mode and move to the inference device
    maskrcnn_model.eval().to(device)
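    # In eval mode a torchvision-style Mask R-CNN takes a list of image tensors
    # and returns one dict per image with 'boxes', 'labels', 'scores' and
    # 'masks', which is what test_step below consumes.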

    start_time = time.time()
    # get the correct masks and mask colors
    if mask_type == "ggo":
        ct_classes = {0: '__bgr', 1: 'GGO'}
        ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
    elif mask_type == "merge":
        ct_classes = {0: '__bgr', 1: 'Lesion'}
        ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
    elif mask_type == "both":
        ct_classes = {0: '__bgr', 1: 'GGO', 2: 'CL'}
        ct_colors = {1: 'red', 2: 'blue', 'mask_cols': np.array([[255, 0, 0], [0, 0, 255]])}

    # create the output directory if it does not exist
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # resolve the model name: use a default if the checkpoint has none,
    # and prefer the name from the config when the two disagree
    if model_name is None:
        model_name = "maskrcnn_segmentation"
    elif config.model_name != model_name:
        print("Using model name from the config.")
        model_name = config.model_name

    # run the inference with the provided hyperparameters
    test_ims = os.listdir(os.path.join(data_dir, img_dir))
    for j, ims in enumerate(test_ims):
        step(os.path.join(data_dir, img_dir, ims), device, maskrcnn_model, model_name,
             confidence_threshold, mask_threshold, save_dir, ct_classes, ct_colors, j)
    end_time = time.time()
    print("Inference took {0:.1f} seconds".format(end_time - start_time))


def test_step(image, device, model, model_name, theta_conf, theta_mask, save_dir, cls, cols, num):
    im = PILImage.open(image)
    # convert the image to RGB, remove the alpha channel
    if im.mode != 'RGB':
        im = im.convert(mode='RGB')
    img = np.array(im)
    # copy the image to use as the background for plotting
    bgr_img = copy.deepcopy(img)
    if img.shape[2] > 3:
        img = img[:, :, :3]
    # torchvision transforms; the rest Mask R-CNN does internally
    t_ = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor()])
    img = t_(img).to(device)
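    # ToTensor converts the HxWxC uint8 array to a CxHxW float tensor in [0, 1];
    # resizing and normalization happen inside the model itself.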
    # no gradients are needed at inference time
    with torch.no_grad():
        out = model([img])
    # scores + bounding boxes + labels + masks
    scores = out[0]['scores']
    bboxes = out[0]['boxes']
    classes = out[0]['labels']
    mask = out[0]['masks']
    best_scores = scores[scores > theta_conf]
    # Are there any detections with confidence above the threshold?
    if len(best_scores):
        # torch.where works on both CPU and CUDA tensors (np.where does not);
        # the boxes are moved to the CPU for plotting
        best_idx = torch.where(scores > theta_conf)[0]
        best_bboxes = bboxes[best_idx].detach().cpu()
        best_classes = classes[best_idx]
        best_masks = mask[best_idx]
        print('bm', best_masks.shape)
        # this is the array for all masks
        mask_array = np.zeros([best_masks[0].shape[1], best_masks[0].shape[2], 3], dtype=np.uint8)
        fig, ax = plt.subplots(1, 1)
        fig.set_size_inches(12, 6)
        ax.axis("off")
        # plot predictions
        for idx, dets in enumerate(best_bboxes):
            # masks must be on the CPU before the conversion to NumPy
            found_masks = best_masks[idx][0].detach().clone().cpu().numpy()
            pred_class = best_classes[idx].item()
            pred_col_n = cols[pred_class]
            pred_class_txt = cls[pred_class]
            pred_col = cols['mask_cols'][pred_class - 1]
            # binarize the predicted mask at theta_mask and paint the class color
            mask_array[found_masks > theta_mask] = pred_col
            rect = Rectangle((dets[0], dets[1]), dets[2] - dets[0], dets[3] - dets[1], linewidth=1,
                             edgecolor=pred_col_n, facecolor='none', linestyle="--")
            ax.text(dets[0] + 40, dets[1], '{0:}'.format(pred_class_txt), fontsize=10, color=pred_col_n)
            ax.text(dets[0], dets[1], '{0:.2f}'.format(best_scores[idx].item()), fontsize=10, color=pred_col_n)
            ax.add_patch(rect)

        # blend the original image with the colored mask overlay and save
        added_image = cv2.addWeighted(bgr_img, 0.5, mask_array, 0.75, gamma=0)
        ax.imshow(added_image)
        fig.savefig(os.path.join(save_dir, model_name + "_" + str(num) + ".png"),
                    bbox_inches='tight', pad_inches=0.0)
        # close the figure to avoid accumulating open figures across images
        plt.close(fig)

    else:
        print("No detections")


# run the inference
if __name__ == '__main__':
    config_test = config.get_config_pars("test")
    main(config_test, test_step)
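
# Typical invocation, as a sketch only: the actual command-line flags are
# defined in config_segmentation.get_config_pars and may differ, e.g.
#   python inference_segmentation.py --ckpt <saved checkpoint> --device cuda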