--- a +++ b/evaluation_mean_ap.py @@ -0,0 +1,153 @@ + +# Mask R-CNN model for lesion segmentation in chest CT scans +# Torchvision detection package is locally re-implemented +# by Alex Ter-Sarkisov@City, University of London +# 2020 + +import argparse +import os +from collections import OrderedDict + +import config_segmentation as config +import torch +import torch.utils.data as data +import torchvision +# implementation of the mAP +import utils +from datasets import dataset_segmentation as dataset +import models.mask_net as mask_net +from models.mask_net.rpn_segmentation import AnchorGenerator +from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead +from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor + +def main(config, step): + devices = ['cpu', 'cuda'] + mask_classes = ['both', 'ggo', 'merge'] + backbones = ['resnet50', 'resnet34', 'resnet18'] + truncation_levels = ['0','1','2'] + assert config.device in devices + assert config.backbone_name in backbones + assert config.truncation in truncation_levels + + if config.device == 'cuda' and torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') + # + model_name = None + ckpt = torch.load(config.ckpt, map_location=device) + if 'model_name' in ckpt.keys(): + model_name = ckpt['model_name'] + + device = torch.device('cpu') + if torch.cuda.is_available(): + device = torch.device('cuda') + + # get the thresholds + confidence_threshold, mask_threshold, save_dir, data_dir, img_dir, gt_dir, mask_type, rpn_nms, roi_nms, backbone_name, truncation \ + = config.confidence_th, config.mask_logits_th, config.save_dir, config.test_data_dir, config.test_imgs_dir, \ + config.gt_dir, config.mask_type, config.rpn_nms_th, config.roi_nms_th, config.backbone_name, config.truncation + + if model_name is None: + model_name = "maskrcnn_segmentation" + elif model_name is not None and config.model_name != model_name: + print("Using model name from the config.") + model_name = config.model_name + + # either 2+1 or 1+1 classes + assert mask_type in mask_classes + if mask_type == "both": + n_c = 3 + else: + n_c = 2 + # dataset interface + dataset_covid_eval_pars = {'stage': 'eval', 'gt': os.path.join(data_dir, gt_dir), + 'data': os.path.join(data_dir, img_dir), 'mask_type': mask_type, 'ignore_small':True} + datapoint_eval_covid = dataset.CovidCTData(**dataset_covid_eval_pars) + dataloader_covid_eval_pars = {'shuffle': False, 'batch_size': 1} + dataloader_eval_covid = data.DataLoader(datapoint_eval_covid, **dataloader_covid_eval_pars) + # MASK R-CNN model + # Alex: these settings could also be added to the config + ckpt = torch.load(config.ckpt, map_location=device) + sizes = ckpt['anchor_generator'].sizes + aspect_ratios = ckpt['anchor_generator'].aspect_ratios + anchor_generator = AnchorGenerator(sizes, aspect_ratios) + print("Anchors: ", anchor_generator.sizes, anchor_generator.aspect_ratios) + + # create modules + # this assumes FPN with 256 channels + box_head = TwoMLPHead(in_channels=7 * 7 * 256, representation_size=128) + if backbone_name == 'resnet50': + maskrcnn_heads = None + box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c) + mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c) + else: + #Backbone->FPN->boxhead->boxpredictor + box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c) + maskrcnn_heads = MaskRCNNHeads(in_channels=256, layers=(128,), dilation=1) + mask_predictor = MaskRCNNPredictor(in_channels=128, dim_reduced=128, num_classes=n_c) + + # keyword arguments + maskrcnn_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024, 'box_detections_per_img': 128, + 'box_nms_thresh': roi_nms, 'box_score_thresh': confidence_threshold, 'rpn_nms_thresh': rpn_nms, + 'box_head': box_head, 'rpn_anchor_generator': anchor_generator, 'mask_head':maskrcnn_heads, + 'mask_predictor': mask_predictor, 'box_predictor': box_predictor} + + # Instantiate the segmentation model + maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation, pretrained_backbone=False, **maskrcnn_args) + # Load weights + maskrcnn_model.load_state_dict(ckpt['model_weights']) + # Set to the evaluation mode + print(maskrcnn_model) + maskrcnn_model.eval().to(device) + # IoU thresholds. By default the model computes AP for each threshold between 0.5 and 0.95 with the step of 0.05 + thresholds = torch.arange(0.5, 1, 0.05).to(device) + mean_aps_all_th = torch.zeros(thresholds.size()[0]).to(device) + ap_th = OrderedDict() + # run the loop for all thresholds + for t, th in enumerate(thresholds): + # main method + ap = step(maskrcnn_model, th, dataloader_eval_covid, device, mask_threshold) + mean_aps_all_th[t] = ap + th_name = 'AP@{0:.2f}'.format(th) + ap_th[th_name] = ap + print("Done evaluation for {}".format(model_name)) + print("mAP:{0:.4f}".format(mean_aps_all_th.mean().item())) + for k, aps in ap_th.items(): + print("{0:}:{1:.4f}".format(k, aps)) + + +def compute_map(model, iou_th, dl, device, mask_th): + mean_aps_this_th = torch.zeros(len(dl), dtype=torch.float) + for v, b in enumerate(dl): + x, y = b + if device == torch.device('cuda'): + x, y['labels'], y['boxes'], y['masks'] = x.to(device), y['labels'].to(device), y['boxes'].to(device), y[ + 'masks'].to(device) + lab = {'boxes': y['boxes'].squeeze_(0), 'labels': y['labels'].squeeze_(0), 'masks': y['masks'].squeeze_(0)} + image = [x.squeeze_(0)] # remove the batch dimension + out = model(image) + # scores + bounding boxes + labels + masks + scores = out[0]['scores'] + bboxes = out[0]['boxes'] + classes = out[0]['labels'] + # remove the empty dimension, + # output_size x 512 x 512 + predict_mask = out[0]['masks'].squeeze_(1) > mask_th + if len(scores) > 0 and len(lab['labels']) > 0: + # threshold for the masks: + ap, _, _, _ = utils.compute_ap(lab['boxes'], lab['labels'], lab['masks'], bboxes, classes, scores, + predict_mask, iou_threshold=iou_th) + mean_aps_this_th[v] = ap + elif not len(scores) and not len(lab['labels']): + mean_aps_this_th[v] = 1 + elif not len(scores) and len(lab['labels']) > 0: + continue + elif len(scores) > 0 and not len(y['labels']): + continue + return mean_aps_this_th.mean().item() + + +if __name__ == "__main__": + config_mean_ap = config.get_config_pars("precision") + main(config_mean_ap, compute_map)