# Mask R-CNN model for lesion segmentation in chest CT scans
# Torchvision detection package is locally re-implemented
# by Alex Ter-Sarkisov@City, University of London
# alex.ter-sarkisov@city.ac.uk
# 2020
import argparse
import time
import pickle
import copy
import torch
import torchvision
import numpy as np
import os
import cv2
import models.mask_net as mask_net
from models.mask_net.rpn_segmentation import AnchorGenerator
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor
from torchvision import transforms
from torch.utils import data
import datasets.dataset_segmentation as dataset
from PIL import Image as PILImage
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import utils
import config_segmentation as config
# main routine: rebuild the model from a checkpoint and run inference on every test image
def main(config, step):
devices = ['cpu', 'cuda']
mask_classes = ['both', 'ggo', 'merge']
backbones = ['resnet50', 'resnet34', 'resnet18']
truncation_levels = ['0','1','2']
assert config.device in devices
assert config.backbone_name in backbones
assert config.truncation in truncation_levels
assert config.mask_type in mask_classes
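    # e.g. device='cuda', backbone_name='resnet18', truncation='1', mask_type='both'
    # (any combination of the options listed above passes these checks)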
if config.device == 'cuda' and torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
    # unpack the configuration: thresholds, directories, mask type, NMS thresholds, backbone
confidence_threshold, mask_threshold, save_dir, data_dir, img_dir, gt_dir, mask_type, rpn_nms, roi_nms, backbone_name, truncation \
= config.confidence_th, config.mask_logits_th, config.save_dir, config.test_data_dir, config.test_imgs_dir, \
config.gt_dir, config.mask_type, config.rpn_nms_th, config.roi_nms_th, config.backbone_name, config.truncation
if mask_type == "both":
n_c = 3
else:
n_c = 2
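    # n_c includes the background: 3 = background + GGO + CL, 2 = background + a single lesion class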
ckpt = torch.load(config.ckpt, map_location=device)
model_name = None
if 'model_name' in ckpt.keys():
model_name = ckpt['model_name']
sizes = ckpt['anchor_generator'].sizes
aspect_ratios = ckpt['anchor_generator'].aspect_ratios
anchor_generator = AnchorGenerator(sizes, aspect_ratios)
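    # anchor sizes/aspect ratios are restored from the checkpoint so that the
    # test-time RPN matches the configuration the model was trained with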
print("Anchors: ", anchor_generator.sizes, anchor_generator.aspect_ratios)
# create modules
# this assumes FPN with 256 channels
box_head = TwoMLPHead(in_channels=7 * 7 * 256, representation_size=128)
if backbone_name == 'resnet50':
maskrcnn_heads = None
box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c)
else:
#Backbone->FPN->boxhead->boxpredictor
box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
maskrcnn_heads = MaskRCNNHeads(in_channels=256, layers=(128,), dilation=1)
mask_predictor = MaskRCNNPredictor(in_channels=128, dim_reduced=128, num_classes=n_c)
# keyword arguments
maskrcnn_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024, 'box_detections_per_img': 100,
'box_nms_thresh': roi_nms, 'box_score_thresh': confidence_threshold, 'rpn_nms_thresh': rpn_nms,
'box_head': box_head, 'rpn_anchor_generator': anchor_generator, 'mask_head':maskrcnn_heads,
'mask_predictor': mask_predictor, 'box_predictor': box_predictor}
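    # note: num_classes stays None because custom box/mask predictors are passed in
    # explicitly; as in torchvision's detection models, the constructor accepts either
    # num_classes or custom predictors, not both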
# Instantiate the segmentation model
maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation, pretrained_backbone=False, **maskrcnn_args)
# Load weights
maskrcnn_model.load_state_dict(ckpt['model_weights'])
    # print the model summary
    print(maskrcnn_model)
    # switch to evaluation mode: the detection model then returns predictions instead of losses
maskrcnn_model.eval().to(device)
start_time = time.time()
# get the correct masks and mask colors
if mask_type == "ggo":
ct_classes = {0: '__bgr', 1: 'GGO'}
ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
elif mask_type == "merge":
ct_classes = {0: '__bgr', 1: 'Lesion'}
ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
elif mask_type == "both":
ct_classes = {0: '__bgr', 1: 'GGO', 2: 'CL'}
ct_colors = {1: 'red', 2: 'blue', 'mask_cols': np.array([[255, 0, 0], [0, 0, 255]])}
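    # ct_colors maps class indices to matplotlib color names for boxes and labels;
    # 'mask_cols' holds the corresponding RGB values used to paint the mask overlay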
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # prefer the name stored in the checkpoint; fall back to a default,
    # and let the config override it if the two disagree
    if model_name is None:
        model_name = "maskrcnn_segmentation"
    elif config.model_name != model_name:
        print("Using the model name from the config.")
        model_name = config.model_name
# run the inference with provided hyperparameters
    test_ims = sorted(os.listdir(os.path.join(data_dir, img_dir)))
    for j, ims in enumerate(test_ims):
        step(os.path.join(data_dir, img_dir, ims), device, maskrcnn_model, model_name,
             confidence_threshold, mask_threshold, save_dir, ct_classes, ct_colors, j)
end_time = time.time()
print("Inference took {0:.1f} seconds".format(end_time - start_time))
def test_step(image, device, model, model_name, theta_conf, theta_mask, save_dir, cls, cols, num):
im = PILImage.open(image)
# convert image to RGB, remove the alpha channel
if im.mode != 'RGB':
im = im.convert(mode='RGB')
img = np.array(im)
    # keep an unmodified copy of the RGB image as the background for plotting
    background_img = copy.deepcopy(img)
    # defensive check: im.convert('RGB') above already guarantees exactly 3 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
# torchvision transforms, the rest Mask R-CNN does internally
t_ = transforms.Compose([
transforms.ToPILImage(),
transforms.ToTensor()])
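    # ToTensor converts the HxWx3 uint8 array to a 3xHxW float tensor in [0, 1];
    # resizing and normalization happen inside the model (the min_size/max_size
    # passed to the constructor above)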
img = t_(img).to(device)
    # no gradients are needed at inference time
    with torch.no_grad():
        out = model([img])
# scores + bounding boxes + labels + masks
scores = out[0]['scores']
bboxes = out[0]['boxes']
classes = out[0]['labels']
mask = out[0]['masks']
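    # in eval mode the model returns one dict per image: 'boxes' [N, 4] in
    # (x1, y1, x2, y2) format, 'labels' [N], 'scores' [N] in decreasing order,
    # and 'masks' [N, 1, H, W] with per-pixel probabilities in [0, 1]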
    # keep only the detections scoring above the confidence threshold
    best_scores = scores[scores > theta_conf]
# Are there any detections with confidence above the threshold?
if len(best_scores):
        # boolean-mask indexing works on both CPU and CUDA tensors
        # (np.where would fail on a CUDA tensor)
        keep = scores > theta_conf
        best_bboxes = bboxes[keep]
        best_classes = classes[keep]
        best_masks = mask[keep]
        print("Kept masks:", best_masks.shape)
        # RGB canvas at the mask resolution; each detection paints its class color into it
        mask_array = np.zeros([best_masks[0].shape[1], best_masks[0].shape[2], 3], dtype=np.uint8)
fig, ax = plt.subplots(1, 1)
fig.set_size_inches(12, 6)
ax.axis("off")
# plot predictions
        for idx, dets in enumerate(best_bboxes):
            # detach().cpu() rather than .to(device): .numpy() only works on CPU tensors
            box = dets.detach().cpu().numpy()
            found_masks = best_masks[idx][0].detach().cpu().numpy()
            pred_class = best_classes[idx].item()
            pred_col_n = cols[pred_class]
            pred_class_txt = cls[pred_class]
            pred_col = cols['mask_cols'][pred_class - 1]
            # paint every pixel above the mask threshold in the class color
            mask_array[found_masks > theta_mask] = pred_col
            rect = Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], linewidth=1,
                             edgecolor=pred_col_n, facecolor='none', linestyle="--")
            ax.text(box[0] + 40, box[1], '{0:}'.format(pred_class_txt), fontsize=10, color=pred_col_n)
            ax.text(box[0], box[1], '{0:.2f}'.format(best_scores[idx].item()), fontsize=10, color=pred_col_n)
            ax.add_patch(rect)
        # blend the background image with the colored mask overlay
        added_image = cv2.addWeighted(background_img, 0.5, mask_array, 0.75, gamma=0)
ax.imshow(added_image)
fig.savefig(os.path.join(save_dir, model_name + "_" + str(num) + ".png"),
bbox_inches='tight', pad_inches=0.0)
    else:
        print("No detections above the confidence threshold in {}".format(image))
# run the inference
if __name__ == '__main__':
config_test = config.get_config_pars("test")
main(config_test, test_step)
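# Example invocation (hypothetical script and flag names -- the actual arguments
# are defined in config_segmentation.get_config_pars):
#   python test_segmentation.py --ckpt weights/segmentation_model.pth --device cuda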