Diff of /inference_segmentation.py [000000] .. [a49583]

# Mask R-CNN model for lesion segmentation in chest CT scans
# Torchvision detection package is locally re-implemented
# by Alex Ter-Sarkisov@City, University of London
# alex.ter-sarkisov@city.ac.uk
# 2020
import argparse
import time
import pickle
import copy
import torch
import torchvision
import numpy as np
import os
import cv2
import models.mask_net as mask_net
from models.mask_net.rpn_segmentation import AnchorGenerator
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor
from torchvision import transforms
from torch.utils import data
import datasets.dataset_segmentation as dataset
from PIL import Image as PILImage
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import utils
import config_segmentation as config


# main method
def main(config, step):
    devices = ['cpu', 'cuda']
    mask_classes = ['both', 'ggo', 'merge']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']
    # validate the configuration
    assert config.device in devices
    assert config.backbone_name in backbones
    assert config.truncation in truncation_levels
    assert config.mask_type in mask_classes
    if config.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # get the configuration: thresholds, paths and model settings
    confidence_threshold, mask_threshold, save_dir, data_dir, img_dir, gt_dir, mask_type, rpn_nms, roi_nms, backbone_name, truncation \
        = config.confidence_th, config.mask_logits_th, config.save_dir, config.test_data_dir, config.test_imgs_dir, \
        config.gt_dir, config.mask_type, config.rpn_nms_th, config.roi_nms_th, config.backbone_name, config.truncation

    # number of classes: background + GGO + CL for 'both', background + a single lesion class otherwise
    if mask_type == "both":
        n_c = 3
    else:
        n_c = 2
    ckpt = torch.load(config.ckpt, map_location=device)
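    # NB: map_location remaps tensors saved on a GPU onto the selected device,
    # so a GPU-trained checkpoint can also be loaded for CPU-only inference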

    model_name = None
    if 'model_name' in ckpt:
        model_name = ckpt['model_name']
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    print("Anchors: ", anchor_generator.sizes, anchor_generator.aspect_ratios)

    # create modules
    # this assumes an FPN with 256 output channels
    box_head = TwoMLPHead(in_channels=7 * 7 * 256, representation_size=128)
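    # in_channels = 7 * 7 * 256: the RoI features (assuming the default 7x7
    # RoIAlign output of the torchvision implementation this package re-implements)
    # flattened over the 256 FPN channels; representation_size is the width of
    # the two fully connected layers in the box head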
    if backbone_name == 'resnet50':
        maskrcnn_heads = None
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c)
    else:
        # Backbone -> FPN -> box head -> box predictor
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        maskrcnn_heads = MaskRCNNHeads(in_channels=256, layers=(128,), dilation=1)
        mask_predictor = MaskRCNNPredictor(in_channels=128, dim_reduced=128, num_classes=n_c)

    # keyword arguments
    maskrcnn_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024, 'box_detections_per_img': 100,
                     'box_nms_thresh': roi_nms, 'box_score_thresh': confidence_threshold, 'rpn_nms_thresh': rpn_nms,
                     'box_head': box_head, 'rpn_anchor_generator': anchor_generator, 'mask_head': maskrcnn_heads,
                     'mask_predictor': mask_predictor, 'box_predictor': box_predictor}
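    # these kwargs are forwarded to the Mask R-CNN constructor; num_classes is
    # left as None because the box and mask predictors are passed in explicitly
    # (in the torchvision convention this package follows, setting both raises
    # an error), and the NMS/score thresholds override the training-time defaults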

    # Instantiate the segmentation model
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation, pretrained_backbone=False, **maskrcnn_args)
    # Load the trained weights
    maskrcnn_model.load_state_dict(ckpt['model_weights'])
    print(maskrcnn_model)
    # Set to evaluation mode and move to the selected device
    maskrcnn_model.eval().to(device)
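    # in eval mode the detection model returns predictions (boxes, labels,
    # scores, masks) instead of training losses, and batchnorm layers use
    # their stored running statistics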

    start_time = time.time()
    # get the correct class labels and mask colors
    if mask_type == "ggo":
        ct_classes = {0: '__bgr', 1: 'GGO'}
        ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
    elif mask_type == "merge":
        ct_classes = {0: '__bgr', 1: 'Lesion'}
        ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
    elif mask_type == "both":
        ct_classes = {0: '__bgr', 1: 'GGO', 2: 'CL'}
        ct_colors = {1: 'red', 2: 'blue', 'mask_cols': np.array([[255, 0, 0], [0, 0, 255]])}

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # resolve the model name: fall back to a default if the checkpoint has none,
    # and let the name in the config override the one in the checkpoint
    if model_name is None:
        model_name = "maskrcnn_segmentation"
    elif config.model_name != model_name:
        print("Using the model name from the config.")
        model_name = config.model_name

    # run the inference with the provided hyperparameters
    test_ims = os.listdir(os.path.join(data_dir, img_dir))
    for j, ims in enumerate(test_ims):
        step(os.path.join(data_dir, img_dir, ims), device, maskrcnn_model, model_name,
             confidence_threshold, mask_threshold, save_dir, ct_classes, ct_colors, j)
    end_time = time.time()
    print("Inference took {0:.1f} seconds".format(end_time - start_time))


def test_step(image, device, model, model_name, theta_conf, theta_mask, save_dir, cls, cols, num):
    im = PILImage.open(image)
    # convert the image to RGB, removing the alpha channel if present
    if im.mode != 'RGB':
        im = im.convert(mode='RGB')
    img = np.array(im)
    # copy the image to use as the background for plotting
    bgr_img = copy.deepcopy(img)
    if img.shape[2] > 3:
        img = img[:, :, :3]
    # torchvision transforms; resizing and normalization are done inside Mask R-CNN
    t_ = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor()])
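    # ToTensor converts the HxWx3 uint8 array to a 3xHxW float tensor in [0, 1];
    # rescaling to min_size/max_size and mean/std normalization happen inside
    # the model's own transform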
    img = t_(img).to(device)
    # no_grad avoids building the autograd graph during inference
    with torch.no_grad():
        out = model([img])
    # scores + bounding boxes + labels + masks, moved to the CPU for numpy/matplotlib
    scores = out[0]['scores'].cpu()
    bboxes = out[0]['boxes'].cpu()
    classes = out[0]['labels'].cpu()
    mask = out[0]['masks'].cpu()
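    # the model returns one dict per input image; 'masks' has shape [N, 1, H, W]
    # with per-pixel probabilities in [0, 1], thresholded by theta_mask below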
    # keep only detections whose confidence clears the threshold
    best_scores = scores[scores > theta_conf]
    if len(best_scores):
        best_idx = np.where(scores > theta_conf)
        best_bboxes = bboxes[best_idx]
        best_classes = classes[best_idx]
        best_masks = mask[best_idx]
        print("Best masks shape:", best_masks.shape)
        # array that accumulates all predicted masks, one RGB color per class
        mask_array = np.zeros([best_masks[0].shape[1], best_masks[0].shape[2], 3], dtype=np.uint8)
        fig, ax = plt.subplots(1, 1)
        fig.set_size_inches(12, 6)
        ax.axis("off")
        # plot the predictions: dashed bounding box, class label, confidence score
        for idx, dets in enumerate(best_bboxes):
            found_masks = best_masks[idx][0].numpy()
            pred_class = best_classes[idx].item()
            pred_col_n = cols[pred_class]
            pred_class_txt = cls[pred_class]
            pred_col = cols['mask_cols'][pred_class - 1]
            mask_array[found_masks > theta_mask] = pred_col
            rect = Rectangle((dets[0], dets[1]), dets[2] - dets[0], dets[3] - dets[1], linewidth=1,
                             edgecolor=pred_col_n, facecolor='none', linestyle="--")
            ax.text(dets[0] + 40, dets[1], pred_class_txt, fontsize=10, color=pred_col_n)
            ax.text(dets[0], dets[1], '{0:.2f}'.format(best_scores[idx].item()), fontsize=10, color=pred_col_n)
            ax.add_patch(rect)

        # blend the input image and the colored mask overlay
        added_image = cv2.addWeighted(bgr_img, 0.5, mask_array, 0.75, gamma=0)
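        # addWeighted saturates to uint8: overlay = 0.5 * image + 0.75 * masks,
        # so lesion pixels are tinted with their class color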
        ax.imshow(added_image)
        fig.savefig(os.path.join(save_dir, model_name + "_" + str(num) + ".png"),
                    bbox_inches='tight', pad_inches=0.0)
        # close the figure so figures do not accumulate across the test loop
        plt.close(fig)
    else:
        print("No detections")


# run the inference
if __name__ == '__main__':
    config_test = config.get_config_pars("test")
    main(config_test, test_step)
    main(config_test, test_step)