Diff of /evaluate_classifier.py [000000] .. [a49583]

Switch to side-by-side view

--- a
+++ b/evaluate_classifier.py
@@ -0,0 +1,129 @@
+# COVID-CT-Mask-Net
+# Torchvision detection package is locally re-implemented
+# Transformed into a classification model with Mask R-CNN backend
+# by Alex Ter-Sarkisov@City, University of London
+# alex.ter-sarkisov@city.ac.uk
+# 2020
+import os
+import re
+import sys
+import time
+import config_classifier as config
+import cv2
+#######################################
+import models.mask_net as mask_net
+import numpy as np
+import torch
+import torchvision
+import torchvision.transforms as transforms
+import utils
+from PIL import Image as PILImage
+from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
+from models.mask_net.rpn import AnchorGenerator
+
+
+
+def main(config, step):
+    torch.manual_seed(time.time())
+    start_time = time.time()
+    devices = ['cpu', 'cuda']
+    backbones = ['resnet50', 'resnet34', 'resnet18']
+    truncation_levels = ['0','1','2']
+
+    assert config.device in devices
+    assert config.backbone_name in backbones
+    assert config.truncation in truncation_levels
+
+    pretrained_model, model_name, test_data_dir, device, rpn_nms, roi_nms, backbone_name, truncation, roi_batch_size, n_c, s_features\
+              = config.ckpt, config.model_name, config.test_data_dir, config.device, config.rpn_nms_th, \
+                config.roi_nms_th, config.backbone_name, config.truncation, config.roi_batch_size, config.num_classes, config.s_features
+
+    if torch.cuda.is_available() and device == 'cuda':
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+    # either 2+1 or 1+1 classes
+    ckpt = torch.load(pretrained_model, map_location=device)
+    # 'box_detections_per_img': batch size input in module S
+    # 'box_score_thresh': negative to accept all predictions
+    covid_mask_net_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024, 'box_detections_per_img': roi_batch_size,
+                           'box_nms_thresh': roi_nms, 'box_score_thresh': -0.01, 'rpn_nms_thresh': rpn_nms}
+
+    print(covid_mask_net_args)
+    # extract anchor generator from the checkpoint
+    sizes = ckpt['anchor_generator'].sizes
+    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
+    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
+    # Faster R-CNN interfaces, masks not implemented at this stage
+    box_head = TwoMLPHead(in_channels=256*7*7, representation_size=128)
+    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
+    # Mask prediction is not necessary, keep it for future extensions
+    covid_mask_net_args['rpn_anchor_generator'] = anchor_generator
+    covid_mask_net_args['box_predictor'] = box_predictor
+    covid_mask_net_args['box_head'] = box_head
+    # representation size of the S classification module
+    # these should be provided in the config
+    covid_mask_net_args['s_representation_size'] = s_features
+    # Instance of the model, copy weights
+    covid_mask_net_model = mask_net.fasterrcnn_resnet_fpn(backbone_name, truncation, **covid_mask_net_args)
+    covid_mask_net_model.load_state_dict(ckpt['model_weights'])
+    covid_mask_net_model.eval().to(device)
+    print(covid_mask_net_model)
+    # confusion matrix
+    confusion_matrix = torch.zeros(3, 3, dtype=torch.int32).to(device)
+
+    for idx, f in enumerate(os.listdir(test_data_dir)):
+        step(f, covid_mask_net_model, test_data_dir, device, confusion_matrix)
+
+    print("------------------------------------------")
+    print("Confusion Matrix for 3-class problem:")
+    print("0: Control, 1: Normal Pneumonia, 2: COVID")
+    print(confusion_matrix)
+    print("------------------------------------------")
+    # confusion matrix
+    cm = confusion_matrix.float()
+    cm[0, :].div_(cm[0, :].sum())
+    cm[1, :].div_(cm[1, :].sum())
+    cm[2, :].div_(cm[2, :].sum())
+    print("------------------------------------------")
+    print("Class Sensitivity:")
+    print(cm)
+    print("------------------------------------------")
+    print('Overall accuracy:')
+    print(confusion_matrix.diag().float().sum().div(confusion_matrix.sum()))
+    end_time = time.time()
+    print("Evaluation took {0:.1f} seconds".format(end_time - start_time))
+
+def test_step(im_input, model, source_dir, device, c_matrix):
+    # CNCB NCOV datasets: the first integer is the correct class:
+    # 0: control
+    # 1: pneumonia
+    # 2: COVID
+    # extract the correct class from the file name
+    correct_class = int(im_input.split('/')[-1].split('_')[0])
+    im = PILImage.open(os.path.join(source_dir, im_input))
+    if im.mode != 'RGB':
+        im = im.convert(mode='RGB')
+    # get rid of alpha channel
+    img = np.array(im)
+    # print(img)
+    if img.shape[2] > 3:
+        img = img[:, :, :3]
+
+    t_ = transforms.Compose([
+        transforms.ToPILImage(),
+        transforms.Resize(512),
+        transforms.ToTensor()])
+    img = t_(img)
+    if device == torch.device('cuda'):
+        img = img.to(device)
+    out = model([img])
+    pred_class = out[0]['final_scores'].argmax().item()
+    # get confusion matrix
+    c_matrix[correct_class, pred_class] += 1
+
+
+# run the inference
+if __name__ == '__main__':
+    config_test = config.get_config_pars_classifier("test")
+    main(config_test, test_step)