# evaluate_classifier.py
# COVID-CT-Mask-Net
# The torchvision detection package is re-implemented locally and
# transformed into a classification model with a Mask R-CNN backend.
# by Alex Ter-Sarkisov, City, University of London
# alex.ter-sarkisov@city.ac.uk
# 2020
import os
8
import re
9
import sys
10
import time
11
import config_classifier as config
12
import cv2
13
#######################################
14
import models.mask_net as mask_net
15
import numpy as np
16
import torch
17
import torchvision
18
import torchvision.transforms as transforms
19
import utils
20
from PIL import Image as PILImage
21
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
22
from models.mask_net.rpn import AnchorGenerator
23
24
25
26
def main(config, step):
    """Evaluate a COVID-CT-Mask-Net checkpoint on a directory of test images.

    Builds the model described by ``config`` and loads the weights stored in
    ``config.ckpt``. It then calls ``step`` once per entry of
    ``config.test_data_dir``, in sorted order. Finally it prints the 3-class
    confusion matrix, the per-class sensitivity and the overall accuracy.

    Args:
        config: object exposing ``device``, ``backbone_name``, ``truncation``,
            ``ckpt``, ``test_data_dir``, ``rpn_nms_th``, ``roi_nms_th``,
            ``roi_batch_size``, ``num_classes`` and ``s_features``.
        step: callable ``step(file_name, model, data_dir, device, c_matrix)``
            that evaluates one file and updates ``c_matrix`` in place.
    """
    # No fixed seed: seeded from the clock, as before (torch casts to int anyway).
    torch.manual_seed(int(time.time()))
    start_time = time.time()
    devices = ['cpu', 'cuda']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']

    assert config.device in devices, "unsupported device: {}".format(config.device)
    assert config.backbone_name in backbones, "unsupported backbone: {}".format(config.backbone_name)
    assert config.truncation in truncation_levels, "unsupported truncation: {}".format(config.truncation)

    # Fall back to the CPU when CUDA is requested but not available.
    if torch.cuda.is_available() and config.device == 'cuda':
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    covid_mask_net_model = _build_model(config, device)

    # Indexed [true class, predicted class] by the step function.
    confusion_matrix = torch.zeros(3, 3, dtype=torch.int32).to(device)
    test_data_dir = config.test_data_dir
    # Inference only: no_grad avoids building autograd graphs for every image.
    # Sorting makes the evaluation order reproducible across runs.
    with torch.no_grad():
        for f in sorted(os.listdir(test_data_dir)):
            step(f, covid_mask_net_model, test_data_dir, device, confusion_matrix)

    _print_report(confusion_matrix)
    end_time = time.time()
    print("Evaluation took {0:.1f} seconds".format(end_time - start_time))


def _build_model(cfg, device):
    """Build COVID-CT-Mask-Net from ``cfg`` and load the weights in ``cfg.ckpt``.

    Returns the model in eval mode, moved to ``device``.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects (this
    # checkpoint stores the AnchorGenerator instance itself), so only load
    # trusted checkpoints. Newer PyTorch releases default to
    # weights_only=True, which may reject such a checkpoint -- confirm
    # against the installed version.
    ckpt = torch.load(cfg.ckpt, map_location=device)
    # 'box_detections_per_img': batch size input in module S
    # 'box_score_thresh': negative to accept all predictions
    covid_mask_net_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024,
                           'box_detections_per_img': cfg.roi_batch_size,
                           'box_nms_thresh': cfg.roi_nms_th, 'box_score_thresh': -0.01,
                           'rpn_nms_thresh': cfg.rpn_nms_th}
    print(covid_mask_net_args)
    # Reuse the anchor sizes / aspect ratios saved with the checkpoint.
    saved_anchors = ckpt['anchor_generator']
    anchor_generator = AnchorGenerator(saved_anchors.sizes, saved_anchors.aspect_ratios)
    # Faster R-CNN interfaces; masks are not implemented at this stage.
    # The detector has either 2+1 or 1+1 classes (cfg.num_classes).
    box_head = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=128)
    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=cfg.num_classes)
    covid_mask_net_args['rpn_anchor_generator'] = anchor_generator
    covid_mask_net_args['box_predictor'] = box_predictor
    covid_mask_net_args['box_head'] = box_head
    # Representation size of the S classification module (from the config).
    covid_mask_net_args['s_representation_size'] = cfg.s_features
    model = mask_net.fasterrcnn_resnet_fpn(cfg.backbone_name, cfg.truncation, **covid_mask_net_args)
    model.load_state_dict(ckpt['model_weights'])
    model.eval().to(device)
    print(model)
    return model


def _print_report(confusion_matrix):
    """Print the confusion matrix, per-class sensitivity and overall accuracy.

    ``confusion_matrix`` is indexed ``[true class, predicted class]``; a class
    with no images gets a NaN sensitivity row (0 / 0).
    """
    print("------------------------------------------")
    print("Confusion Matrix for 3-class problem:")
    print("0: Control, 1: Normal Pneumonia, 2: COVID")
    print(confusion_matrix)
    print("------------------------------------------")
    # Divide each row by the number of images of that true class.
    cm = confusion_matrix.float()
    cm = cm / cm.sum(dim=1, keepdim=True)
    print("------------------------------------------")
    print("Class Sensitivity:")
    print(cm)
    print("------------------------------------------")
    print('Overall accuracy:')
    print(confusion_matrix.diag().float().sum().div(confusion_matrix.sum()))

def test_step(im_input, model, source_dir, device, c_matrix):
    """Classify one image and record the outcome in the confusion matrix.

    The ground-truth label is the integer before the first underscore in the
    file name (CNCB NCOV datasets: 0 = control, 1 = pneumonia, 2 = COVID).

    Args:
        im_input: image file name, relative to ``source_dir``.
        model: classifier called with a list holding one image tensor; its
            first output must contain a ``'final_scores'`` tensor.
        source_dir: directory containing ``im_input``.
        device: ``torch.device`` the model runs on.
        c_matrix: 2-D tensor incremented in place at ``[true, predicted]``.
    """
    correct_class = int(os.path.basename(im_input).split('_')[0])

    t_ = transforms.Compose([
        transforms.Resize(512),
        transforms.ToTensor()])
    # The context manager closes the file handle. Converting to RGB guarantees
    # three channels, so no alpha channel has to be stripped afterwards.
    # The transform runs inside the with-block so pixel data is read before
    # the file is closed.
    with PILImage.open(os.path.join(source_dir, im_input)) as im:
        rgb = im if im.mode == 'RGB' else im.convert(mode='RGB')
        img = t_(rgb)

    # .to() is a no-op when the tensor already lives on `device`.
    img = img.to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        out = model([img])
    pred_class = out[0]['final_scores'].argmax().item()
    c_matrix[correct_class, pred_class] += 1

# Script entry point: evaluate the classifier using the "test" configuration.
if __name__ == '__main__':
    main(config.get_config_pars_classifier("test"), step=test_step)