evaluate_classifier.py


# COVID-CT-Mask-Net
# Torchvision detection package is locally re-implemented
# Transformed into a classification model with Mask R-CNN backend
# by Alex Ter-Sarkisov@City, University of London
# alex.ter-sarkisov@city.ac.uk
# 2020
import os
import re
import sys
import time
import config_classifier as config
import cv2
#######################################
import models.mask_net as mask_net
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import utils
from PIL import Image as PILImage
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from models.mask_net.rpn import AnchorGenerator


def main(config, step):
    torch.manual_seed(time.time())
    start_time = time.time()
    # sanity-check the configuration
    devices = ['cpu', 'cuda']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']
    assert config.device in devices
    assert config.backbone_name in backbones
    assert config.truncation in truncation_levels
    # unpack the test configuration
    pretrained_model, model_name, test_data_dir, device, rpn_nms, roi_nms, backbone_name, truncation, roi_batch_size, n_c, s_features \
        = config.ckpt, config.model_name, config.test_data_dir, config.device, config.rpn_nms_th, \
        config.roi_nms_th, config.backbone_name, config.truncation, config.roi_batch_size, config.num_classes, config.s_features
    # select the device
    if torch.cuda.is_available() and device == 'cuda':
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # either 2+1 or 1+1 classes
    ckpt = torch.load(pretrained_model, map_location=device)
    # 'box_detections_per_img': batch size input in module S
    # 'box_score_thresh': negative to accept all predictions
    covid_mask_net_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024,
                           'box_detections_per_img': roi_batch_size, 'box_nms_thresh': roi_nms,
                           'box_score_thresh': -0.01, 'rpn_nms_thresh': rpn_nms}
    print(covid_mask_net_args)
    # extract the anchor generator from the checkpoint
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    # Faster R-CNN interfaces, masks not implemented at this stage
    box_head = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=128)
    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
    # Mask prediction is not necessary, keep it for future extensions
    covid_mask_net_args['rpn_anchor_generator'] = anchor_generator
    covid_mask_net_args['box_predictor'] = box_predictor
    covid_mask_net_args['box_head'] = box_head
    # representation size of the S classification module, provided in the config
    covid_mask_net_args['s_representation_size'] = s_features
    # instantiate the model and copy the pretrained weights
    covid_mask_net_model = mask_net.fasterrcnn_resnet_fpn(backbone_name, truncation, **covid_mask_net_args)
    covid_mask_net_model.load_state_dict(ckpt['model_weights'])
    covid_mask_net_model.eval().to(device)
    print(covid_mask_net_model)
    # accumulate the confusion matrix over the test set
    confusion_matrix = torch.zeros(3, 3, dtype=torch.int32).to(device)
    for idx, f in enumerate(os.listdir(test_data_dir)):
        step(f, covid_mask_net_model, test_data_dir, device, confusion_matrix)
    print("------------------------------------------")
    print("Confusion Matrix for 3-class problem:")
    print("0: Control, 1: Normal Pneumonia, 2: COVID")
    print(confusion_matrix)
    print("------------------------------------------")
    # row-normalize the confusion matrix to get per-class sensitivity
    cm = confusion_matrix.float()
    cm[0, :].div_(cm[0, :].sum())
    cm[1, :].div_(cm[1, :].sum())
    cm[2, :].div_(cm[2, :].sum())
    print("------------------------------------------")
    print("Class Sensitivity:")
    print(cm)
    print("------------------------------------------")
    # overall accuracy: correctly classified images / all test images
    print('Overall accuracy:')
    print(confusion_matrix.diag().float().sum().div(confusion_matrix.sum()))
    end_time = time.time()
    print("Evaluation took {0:.1f} seconds".format(end_time - start_time))


def test_step(im_input, model, source_dir, device, c_matrix):
    # CNCB NCOV dataset: the first integer in the file name is the correct class:
    # 0: control
    # 1: pneumonia
    # 2: COVID
    # extract the correct class from the file name
    correct_class = int(im_input.split('/')[-1].split('_')[0])
    im = PILImage.open(os.path.join(source_dir, im_input))
    if im.mode != 'RGB':
        im = im.convert(mode='RGB')
    # get rid of the alpha channel if one is still present
    img = np.array(im)
    if img.shape[2] > 3:
        img = img[:, :, :3]
    # resize to the model's input size and convert to a tensor
    t_ = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(512),
        transforms.ToTensor()])
    img = t_(img)
    if device == torch.device('cuda'):
        img = img.to(device)
    # forward pass; 'final_scores' holds one score per class
    out = model([img])
    pred_class = out[0]['final_scores'].argmax().item()
    # update the confusion matrix
    c_matrix[correct_class, pred_class] += 1
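
# Example of the expected file naming (an assumption based on the class codes
# above): an image named '2_0042.png' would be scored against ground-truth
# class 2 (COVID), '0_0007.png' against class 0 (control).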


# run the inference
if __name__ == '__main__':
    config_test = config.get_config_pars_classifier("test")
    main(config_test, test_step)
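
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original script): main() only
# reads the config attributes unpacked at the top of the function, so a
# stand-in config object can be passed directly. All values below are
# placeholders, not the settings used for the released checkpoints.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(ckpt='covid_mask_net.pth', model_name='covid_mask_net',
#                         test_data_dir='test_split/', device='cuda',
#                         rpn_nms_th=0.75, roi_nms_th=0.75,
#                         backbone_name='resnet50', truncation='0',
#                         roi_batch_size=256, num_classes=2, s_features=512)
#   main(cfg, test_step)
# ---------------------------------------------------------------------------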