|
a |
|
b/evaluation_mean_ap.py |
|
|
1 |
|
|
|
2 |
# Mask R-CNN model for lesion segmentation in chest CT scans |
|
|
3 |
# Torchvision detection package is locally re-implemented |
|
|
4 |
# by Alex Ter-Sarkisov@City, University of London |
|
|
5 |
# 2020 |
|
|
6 |
|
|
|
7 |
import argparse |
|
|
8 |
import os |
|
|
9 |
from collections import OrderedDict |
|
|
10 |
|
|
|
11 |
import config_segmentation as config |
|
|
12 |
import torch |
|
|
13 |
import torch.utils.data as data |
|
|
14 |
import torchvision |
|
|
15 |
# implementation of the mAP |
|
|
16 |
import utils |
|
|
17 |
from datasets import dataset_segmentation as dataset |
|
|
18 |
import models.mask_net as mask_net |
|
|
19 |
from models.mask_net.rpn_segmentation import AnchorGenerator |
|
|
20 |
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead |
|
|
21 |
from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor |
|
|
22 |
|
|
|
23 |
def main(config, step):
    """Evaluate a trained Mask R-CNN checkpoint with mean average precision (mAP).

    Builds the evaluation dataloader and the Mask R-CNN model described by
    ``config``, loads the checkpoint weights, then calls ``step`` once for each
    IoU threshold 0.50, 0.55, ..., 0.95. Prints the overall mAP (the mean of
    the per-threshold APs) followed by every per-threshold AP.

    Args:
        config: configuration namespace. Attributes read here: ``device``,
            ``backbone_name``, ``truncation``, ``ckpt``, ``model_name`` (only
            when the checkpoint stores a name), ``confidence_th``,
            ``mask_logits_th``, ``test_data_dir``, ``test_imgs_dir``,
            ``gt_dir``, ``mask_type``, ``rpn_nms_th`` and ``roi_nms_th``.
        step: callable ``step(model, iou_threshold, dataloader, device,
            mask_threshold)`` returning the AP (a number) at one IoU threshold.
    """
    devices = ['cpu', 'cuda']
    mask_classes = ['both', 'ggo', 'merge']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']
    assert config.device in devices
    assert config.backbone_name in backbones
    assert config.truncation in truncation_levels

    # Choose the device exactly once: honour config.device and fall back to
    # the CPU when CUDA is unavailable. Never silently upgrade 'cpu' to 'cuda'.
    if config.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Load the checkpoint a single time. It provides the (optional) model name,
    # the anchor generator object and the model weights.
    # NOTE(review): torch.load unpickles arbitrary Python objects (the anchor
    # generator is stored as an object), so only load trusted checkpoints.
    ckpt = torch.load(config.ckpt, map_location=device)

    # A checkpoint without a stored name gets the default name. Otherwise the
    # name from the config is used, with a notice when the two disagree.
    if 'model_name' in ckpt:
        if config.model_name != ckpt['model_name']:
            print("Using model name from the config.")
        model_name = config.model_name
    else:
        model_name = "maskrcnn_segmentation"

    # Thresholds, data locations and architecture choices.
    confidence_threshold = config.confidence_th
    mask_threshold = config.mask_logits_th
    rpn_nms = config.rpn_nms_th
    roi_nms = config.roi_nms_th
    data_dir, img_dir, gt_dir = config.test_data_dir, config.test_imgs_dir, config.gt_dir
    mask_type = config.mask_type
    backbone_name, truncation = config.backbone_name, config.truncation

    # Either 2+1 classes ('both') or 1+1 classes (any other mask type).
    assert mask_type in mask_classes
    n_c = 3 if mask_type == "both" else 2

    # Dataset interface: one image per batch, in a fixed (unshuffled) order.
    dataset_covid_eval_pars = {'stage': 'eval', 'gt': os.path.join(data_dir, gt_dir),
                               'data': os.path.join(data_dir, img_dir), 'mask_type': mask_type,
                               'ignore_small': True}
    datapoint_eval_covid = dataset.CovidCTData(**dataset_covid_eval_pars)
    dataloader_covid_eval_pars = {'shuffle': False, 'batch_size': 1}
    dataloader_eval_covid = data.DataLoader(datapoint_eval_covid, **dataloader_covid_eval_pars)

    # MASK R-CNN model: the anchor sizes/aspect ratios are taken from the checkpoint.
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    print("Anchors: ", anchor_generator.sizes, anchor_generator.aspect_ratios)

    # Create modules; this assumes an FPN with 256 channels.
    box_head = TwoMLPHead(in_channels=7 * 7 * 256, representation_size=128)
    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
    if backbone_name == 'resnet50':
        maskrcnn_heads = None
        mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c)
    else:
        # Backbone -> FPN -> box head -> box predictor; custom mask head with 128 channels.
        maskrcnn_heads = MaskRCNNHeads(in_channels=256, layers=(128,), dilation=1)
        mask_predictor = MaskRCNNPredictor(in_channels=128, dim_reduced=128, num_classes=n_c)

    maskrcnn_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024, 'box_detections_per_img': 128,
                     'box_nms_thresh': roi_nms, 'box_score_thresh': confidence_threshold, 'rpn_nms_thresh': rpn_nms,
                     'box_head': box_head, 'rpn_anchor_generator': anchor_generator, 'mask_head': maskrcnn_heads,
                     'mask_predictor': mask_predictor, 'box_predictor': box_predictor}

    # Instantiate the segmentation model and load the trained weights.
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation, pretrained_backbone=False,
                                                  **maskrcnn_args)
    maskrcnn_model.load_state_dict(ckpt['model_weights'])
    print(maskrcnn_model)
    # Set to evaluation mode and move to the chosen device.
    maskrcnn_model.eval().to(device)

    # IoU thresholds 0.50, 0.55, ..., 0.95. linspace fixes the count at 10,
    # avoiding arange's floating-point step-length ambiguity.
    thresholds = torch.linspace(0.5, 0.95, 10, device=device)
    mean_aps_all_th = torch.zeros(thresholds.size(0), device=device)
    ap_th = OrderedDict()
    for t, th in enumerate(thresholds):
        ap = step(maskrcnn_model, th, dataloader_eval_covid, device, mask_threshold)
        mean_aps_all_th[t] = ap
        ap_th['AP@{0:.2f}'.format(th.item())] = ap
    print("Done evaluation for {}".format(model_name))
    print("mAP:{0:.4f}".format(mean_aps_all_th.mean().item()))
    for k, aps in ap_th.items():
        print("{0:}:{1:.4f}".format(k, aps))
|
|
118 |
|
|
|
119 |
|
|
|
120 |
def compute_map(model, iou_th, dl, device, mask_th):
    """Return the mean per-image AP of ``model`` over ``dl`` at one IoU threshold.

    Each item of ``dl`` is an ``(image, target)`` pair holding a single image
    (batch size 1). ``target`` is a dict with ``'boxes'``, ``'labels'`` and
    ``'masks'`` tensors whose leading dimension is the batch dimension.

    Per-image score:
      * predictions and ground truth both present -> AP from ``utils.compute_ap``;
      * neither present -> 1;
      * only one of the two present -> 0.

    Args:
        model: detection model called as ``model([image])`` and returning a
            list whose first element has ``'scores'``, ``'boxes'``, ``'labels'``
            and ``'masks'`` entries.
        iou_th: IoU threshold forwarded to ``utils.compute_ap``.
        dl: iterable of ``(image, target)`` pairs that also supports ``len``.
        device: device the image and target tensors are moved to.
        mask_th: threshold applied to the predicted mask values.

    Returns:
        The mean of the per-image scores as a Python float (``nan`` if ``dl``
        is empty).
    """
    mean_aps_this_th = torch.zeros(len(dl), dtype=torch.float)
    # Pure inference: do not build autograd graphs (saves memory and time).
    with torch.no_grad():
        for v, (x, y) in enumerate(dl):
            # .to() is a no-op for tensors already on `device`, so always move;
            # this also covers indexed devices such as 'cuda:1'.
            image = x.to(device).squeeze(0)  # remove the batch dimension
            lab = {key: y[key].to(device).squeeze(0) for key in ('boxes', 'labels', 'masks')}

            out = model([image])[0]
            # scores + bounding boxes + labels + masks
            scores = out['scores']
            n_pred = len(scores)
            n_gt = len(lab['labels'])

            if n_pred > 0 and n_gt > 0:
                # Drop the singleton channel dimension and binarise the masks.
                predict_mask = out['masks'].squeeze(1) > mask_th
                ap, _, _, _ = utils.compute_ap(lab['boxes'], lab['labels'], lab['masks'],
                                               out['boxes'], out['labels'], scores,
                                               predict_mask, iou_threshold=iou_th)
                mean_aps_this_th[v] = ap
            elif n_pred == 0 and n_gt == 0:
                # Nothing to find and nothing (falsely) found: perfect score.
                mean_aps_this_th[v] = 1
            # Otherwise exactly one side is empty: the score stays 0.
    return mean_aps_this_th.mean().item()
|
|
149 |
|
|
|
150 |
|
|
|
151 |
if __name__ == "__main__":
    # Script entry point: evaluate using the "precision" configuration, with
    # compute_map as the per-IoU-threshold evaluation step.
    precision_config = config.get_config_pars("precision")
    main(precision_config, compute_map)