Diff of /evaluation_mean_ap.py [000000] .. [a49583]

Switch to unified view

a b/evaluation_mean_ap.py
1
2
# Mask R-CNN model for lesion segmentation in chest CT scans
3
# Torchvision detection package is locally re-implemented
4
# by Alex Ter-Sarkisov@City, University of London
5
# 2020
6
7
import argparse
8
import os
9
from collections import OrderedDict
10
11
import config_segmentation as config
12
import torch
13
import torch.utils.data as data
14
import torchvision
15
# implementation of the mAP
16
import utils
17
from datasets import dataset_segmentation as dataset
18
import models.mask_net as mask_net
19
from models.mask_net.rpn_segmentation import AnchorGenerator
20
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
21
from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor
22
23
def main(config, step):
24
    devices = ['cpu', 'cuda']
25
    mask_classes = ['both', 'ggo', 'merge']
26
    backbones = ['resnet50', 'resnet34', 'resnet18']
27
    truncation_levels = ['0','1','2']
28
    assert config.device in devices
29
    assert config.backbone_name in backbones
30
    assert config.truncation in truncation_levels
31
32
    if config.device == 'cuda' and torch.cuda.is_available():
33
        device = torch.device('cuda')
34
    else:
35
        device = torch.device('cpu')
36
    #
37
    model_name = None
38
    ckpt = torch.load(config.ckpt, map_location=device)
39
    if 'model_name' in ckpt.keys():
40
        model_name = ckpt['model_name']
41
42
    device = torch.device('cpu')
43
    if torch.cuda.is_available():
44
        device = torch.device('cuda')
45
46
    # get the thresholds
47
    confidence_threshold, mask_threshold, save_dir, data_dir, img_dir, gt_dir, mask_type, rpn_nms, roi_nms, backbone_name, truncation \
48
        = config.confidence_th, config.mask_logits_th, config.save_dir, config.test_data_dir, config.test_imgs_dir, \
49
        config.gt_dir, config.mask_type, config.rpn_nms_th, config.roi_nms_th, config.backbone_name, config.truncation
50
51
    if model_name is None:
52
        model_name = "maskrcnn_segmentation"
53
    elif model_name is not None and config.model_name != model_name:
54
        print("Using model name from the config.")
55
        model_name = config.model_name
56
57
    # either 2+1 or 1+1 classes
58
    assert mask_type in mask_classes
59
    if mask_type == "both":
60
        n_c = 3
61
    else:
62
        n_c = 2
63
    # dataset interface
64
    dataset_covid_eval_pars = {'stage': 'eval', 'gt': os.path.join(data_dir, gt_dir),
65
                               'data': os.path.join(data_dir, img_dir), 'mask_type': mask_type, 'ignore_small':True}
66
    datapoint_eval_covid = dataset.CovidCTData(**dataset_covid_eval_pars)
67
    dataloader_covid_eval_pars = {'shuffle': False, 'batch_size': 1}
68
    dataloader_eval_covid = data.DataLoader(datapoint_eval_covid, **dataloader_covid_eval_pars)
69
    # MASK R-CNN model
70
    # Alex: these settings could also be added to the config
71
    ckpt = torch.load(config.ckpt, map_location=device)
72
    sizes = ckpt['anchor_generator'].sizes
73
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
74
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
75
    print("Anchors: ", anchor_generator.sizes, anchor_generator.aspect_ratios)
76
77
    # create modules
78
    # this assumes FPN with 256 channels
79
    box_head = TwoMLPHead(in_channels=7 * 7 * 256, representation_size=128)
80
    if backbone_name == 'resnet50':
81
       maskrcnn_heads = None
82
       box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
83
       mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c)
84
    else:
85
       #Backbone->FPN->boxhead->boxpredictor
86
       box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
87
       maskrcnn_heads = MaskRCNNHeads(in_channels=256, layers=(128,), dilation=1)
88
       mask_predictor = MaskRCNNPredictor(in_channels=128, dim_reduced=128, num_classes=n_c)
89
90
    # keyword arguments
91
    maskrcnn_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024, 'box_detections_per_img': 128,
92
                     'box_nms_thresh': roi_nms, 'box_score_thresh': confidence_threshold, 'rpn_nms_thresh': rpn_nms,
93
                     'box_head': box_head, 'rpn_anchor_generator': anchor_generator, 'mask_head':maskrcnn_heads,
94
                     'mask_predictor': mask_predictor, 'box_predictor': box_predictor}
95
96
    # Instantiate the segmentation model
97
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation, pretrained_backbone=False, **maskrcnn_args)
98
    # Load weights
99
    maskrcnn_model.load_state_dict(ckpt['model_weights'])
100
    # Set to the evaluation mode
101
    print(maskrcnn_model)
102
    maskrcnn_model.eval().to(device)
103
    # IoU thresholds. By default the model computes AP for each threshold between 0.5 and 0.95 with the step of 0.05
104
    thresholds = torch.arange(0.5, 1, 0.05).to(device)
105
    mean_aps_all_th = torch.zeros(thresholds.size()[0]).to(device)
106
    ap_th = OrderedDict()
107
    # run the loop for all thresholds
108
    for t, th in enumerate(thresholds):
109
        # main method
110
        ap = step(maskrcnn_model, th, dataloader_eval_covid, device, mask_threshold)
111
        mean_aps_all_th[t] = ap
112
        th_name = 'AP@{0:.2f}'.format(th)
113
        ap_th[th_name] = ap
114
    print("Done evaluation for {}".format(model_name))
115
    print("mAP:{0:.4f}".format(mean_aps_all_th.mean().item()))
116
    for k, aps in ap_th.items():
117
        print("{0:}:{1:.4f}".format(k, aps))
118
119
120
def compute_map(model, iou_th, dl, device, mask_th):
121
    mean_aps_this_th = torch.zeros(len(dl), dtype=torch.float)
122
    for v, b in enumerate(dl):
123
        x, y = b
124
        if device == torch.device('cuda'):
125
            x, y['labels'], y['boxes'], y['masks'] = x.to(device), y['labels'].to(device), y['boxes'].to(device), y[
126
                'masks'].to(device)
127
        lab = {'boxes': y['boxes'].squeeze_(0), 'labels': y['labels'].squeeze_(0), 'masks': y['masks'].squeeze_(0)}
128
        image = [x.squeeze_(0)]  # remove the batch dimension
129
        out = model(image)
130
        # scores + bounding boxes + labels + masks
131
        scores = out[0]['scores']
132
        bboxes = out[0]['boxes']
133
        classes = out[0]['labels']
134
        # remove the empty dimension,
135
        # output_size x 512 x 512
136
        predict_mask = out[0]['masks'].squeeze_(1) > mask_th
137
        if len(scores) > 0 and len(lab['labels']) > 0:
138
            # threshold for the masks:
139
            ap, _, _, _ = utils.compute_ap(lab['boxes'], lab['labels'], lab['masks'], bboxes, classes, scores,
140
                                           predict_mask, iou_threshold=iou_th)
141
            mean_aps_this_th[v] = ap
142
        elif not len(scores) and not len(lab['labels']):
143
            mean_aps_this_th[v] = 1
144
        elif not len(scores) and len(lab['labels']) > 0:
145
            continue
146
        elif len(scores) > 0 and not len(y['labels']):
147
            continue
148
    return mean_aps_this_th.mean().item()
149
150
151
if __name__ == "__main__":
152
    config_mean_ap = config.get_config_pars("precision")
153
    main(config_mean_ap, compute_map)