|
a |
|
b/evaluate_classifier.py |
|
|
1 |
# COVID-CT-Mask-Net |
|
|
2 |
# Torchvision detection package is locally re-implemented |
|
|
3 |
# Transformed into a classification model with Mask R-CNN backend |
|
|
4 |
# by Alex Ter-Sarkisov@City, University of London |
|
|
5 |
# alex.ter-sarkisov@city.ac.uk |
|
|
6 |
# 2020 |
|
|
7 |
import os |
|
|
8 |
import re |
|
|
9 |
import sys |
|
|
10 |
import time |
|
|
11 |
import config_classifier as config |
|
|
12 |
import cv2 |
|
|
13 |
####################################### |
|
|
14 |
import models.mask_net as mask_net |
|
|
15 |
import numpy as np |
|
|
16 |
import torch |
|
|
17 |
import torchvision |
|
|
18 |
import torchvision.transforms as transforms |
|
|
19 |
import utils |
|
|
20 |
from PIL import Image as PILImage |
|
|
21 |
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead |
|
|
22 |
from models.mask_net.rpn import AnchorGenerator |
|
|
23 |
|
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def main(config, step): |
|
|
27 |
torch.manual_seed(time.time()) |
|
|
28 |
start_time = time.time() |
|
|
29 |
devices = ['cpu', 'cuda'] |
|
|
30 |
backbones = ['resnet50', 'resnet34', 'resnet18'] |
|
|
31 |
truncation_levels = ['0','1','2'] |
|
|
32 |
|
|
|
33 |
assert config.device in devices |
|
|
34 |
assert config.backbone_name in backbones |
|
|
35 |
assert config.truncation in truncation_levels |
|
|
36 |
|
|
|
37 |
pretrained_model, model_name, test_data_dir, device, rpn_nms, roi_nms, backbone_name, truncation, roi_batch_size, n_c, s_features\ |
|
|
38 |
= config.ckpt, config.model_name, config.test_data_dir, config.device, config.rpn_nms_th, \ |
|
|
39 |
config.roi_nms_th, config.backbone_name, config.truncation, config.roi_batch_size, config.num_classes, config.s_features |
|
|
40 |
|
|
|
41 |
if torch.cuda.is_available() and device == 'cuda': |
|
|
42 |
device = torch.device('cuda') |
|
|
43 |
else: |
|
|
44 |
device = torch.device('cpu') |
|
|
45 |
# either 2+1 or 1+1 classes |
|
|
46 |
ckpt = torch.load(pretrained_model, map_location=device) |
|
|
47 |
# 'box_detections_per_img': batch size input in module S |
|
|
48 |
# 'box_score_thresh': negative to accept all predictions |
|
|
49 |
covid_mask_net_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024, 'box_detections_per_img': roi_batch_size, |
|
|
50 |
'box_nms_thresh': roi_nms, 'box_score_thresh': -0.01, 'rpn_nms_thresh': rpn_nms} |
|
|
51 |
|
|
|
52 |
print(covid_mask_net_args) |
|
|
53 |
# extract anchor generator from the checkpoint |
|
|
54 |
sizes = ckpt['anchor_generator'].sizes |
|
|
55 |
aspect_ratios = ckpt['anchor_generator'].aspect_ratios |
|
|
56 |
anchor_generator = AnchorGenerator(sizes, aspect_ratios) |
|
|
57 |
# Faster R-CNN interfaces, masks not implemented at this stage |
|
|
58 |
box_head = TwoMLPHead(in_channels=256*7*7, representation_size=128) |
|
|
59 |
box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c) |
|
|
60 |
# Mask prediction is not necessary, keep it for future extensions |
|
|
61 |
covid_mask_net_args['rpn_anchor_generator'] = anchor_generator |
|
|
62 |
covid_mask_net_args['box_predictor'] = box_predictor |
|
|
63 |
covid_mask_net_args['box_head'] = box_head |
|
|
64 |
# representation size of the S classification module |
|
|
65 |
# these should be provided in the config |
|
|
66 |
covid_mask_net_args['s_representation_size'] = s_features |
|
|
67 |
# Instance of the model, copy weights |
|
|
68 |
covid_mask_net_model = mask_net.fasterrcnn_resnet_fpn(backbone_name, truncation, **covid_mask_net_args) |
|
|
69 |
covid_mask_net_model.load_state_dict(ckpt['model_weights']) |
|
|
70 |
covid_mask_net_model.eval().to(device) |
|
|
71 |
print(covid_mask_net_model) |
|
|
72 |
# confusion matrix |
|
|
73 |
confusion_matrix = torch.zeros(3, 3, dtype=torch.int32).to(device) |
|
|
74 |
|
|
|
75 |
for idx, f in enumerate(os.listdir(test_data_dir)): |
|
|
76 |
step(f, covid_mask_net_model, test_data_dir, device, confusion_matrix) |
|
|
77 |
|
|
|
78 |
print("------------------------------------------") |
|
|
79 |
print("Confusion Matrix for 3-class problem:") |
|
|
80 |
print("0: Control, 1: Normal Pneumonia, 2: COVID") |
|
|
81 |
print(confusion_matrix) |
|
|
82 |
print("------------------------------------------") |
|
|
83 |
# confusion matrix |
|
|
84 |
cm = confusion_matrix.float() |
|
|
85 |
cm[0, :].div_(cm[0, :].sum()) |
|
|
86 |
cm[1, :].div_(cm[1, :].sum()) |
|
|
87 |
cm[2, :].div_(cm[2, :].sum()) |
|
|
88 |
print("------------------------------------------") |
|
|
89 |
print("Class Sensitivity:") |
|
|
90 |
print(cm) |
|
|
91 |
print("------------------------------------------") |
|
|
92 |
print('Overall accuracy:') |
|
|
93 |
print(confusion_matrix.diag().float().sum().div(confusion_matrix.sum())) |
|
|
94 |
end_time = time.time() |
|
|
95 |
print("Evaluation took {0:.1f} seconds".format(end_time - start_time)) |
|
|
96 |
|
|
|
97 |
def test_step(im_input, model, source_dir, device, c_matrix): |
|
|
98 |
# CNCB NCOV datasets: the first integer is the correct class: |
|
|
99 |
# 0: control |
|
|
100 |
# 1: pneumonia |
|
|
101 |
# 2: COVID |
|
|
102 |
# extract the correct class from the file name |
|
|
103 |
correct_class = int(im_input.split('/')[-1].split('_')[0]) |
|
|
104 |
im = PILImage.open(os.path.join(source_dir, im_input)) |
|
|
105 |
if im.mode != 'RGB': |
|
|
106 |
im = im.convert(mode='RGB') |
|
|
107 |
# get rid of alpha channel |
|
|
108 |
img = np.array(im) |
|
|
109 |
# print(img) |
|
|
110 |
if img.shape[2] > 3: |
|
|
111 |
img = img[:, :, :3] |
|
|
112 |
|
|
|
113 |
t_ = transforms.Compose([ |
|
|
114 |
transforms.ToPILImage(), |
|
|
115 |
transforms.Resize(512), |
|
|
116 |
transforms.ToTensor()]) |
|
|
117 |
img = t_(img) |
|
|
118 |
if device == torch.device('cuda'): |
|
|
119 |
img = img.to(device) |
|
|
120 |
out = model([img]) |
|
|
121 |
pred_class = out[0]['final_scores'].argmax().item() |
|
|
122 |
# get confusion matrix |
|
|
123 |
c_matrix[correct_class, pred_class] += 1 |
|
|
124 |
|
|
|
125 |
|
|
|
126 |
# run the inference |
|
|
127 |
if __name__ == '__main__': |
|
|
128 |
config_test = config.get_config_pars_classifier("test") |
|
|
129 |
main(config_test, test_step) |