Diff of /src/evaluation.py [000000] .. [f45789]

--- a
+++ b/src/evaluation.py
@@ -0,0 +1,141 @@
+import os
+
+import numpy as np
+import torch
+import yaml
+from torch.nn import functional as F
+from tqdm import tqdm
+
+from .metrics import calc_metrics, pr_auc, roc_auc
+
+
+def evaluate(conf):
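+    """Evaluate the trained model on the test set and save metrics.yaml.
+
+    The layout of `conf` is inferred from its usage below: 'device',
+    'dataloaders' (with a 'test' loader), 'experiment_dir', 'data'
+    (with 'batch_size'), 'model', 'best_weights' (a state dict) and
+    'patients_dataset'; the code that builds it is not part of this diff.
+    """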
+    device = conf['device']
+    dataloader = conf['dataloaders']['test']
+    experiment_dir = conf['experiment_dir']
+    batch_size = conf['data']['batch_size']
+
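+    # Restore the best checkpoint weights (judging by the 'best_weights'
+    # key, these are selected upstream during training) before evaluating.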
+    model = conf['model']
+    model.load_state_dict(conf['best_weights'])
+    model = model.to(device)
+    model.eval()
+
+    # Slice-level evaluation: accumulate labels and predicted IH
+    # probabilities over the whole test set, then compute the metrics once.
+    ground_truth = []
+    inferences = []
+
+    batch_bar = tqdm(dataloader, desc='Batch', unit='batches', leave=True)
+    for inputs, labels in batch_bar:
+        inputs = inputs.to(device)
+        with torch.no_grad():
+            outputs = model(inputs)
+
+        # Probability of the positive class (index 1, i.e. IH present).
+        probs = F.softmax(outputs, dim=1)[:, 1]
+        ground_truth.append(labels)
+        inferences.append(probs.cpu())
+
+    # A single concatenation at the end avoids the quadratic cost of
+    # growing the tensors with torch.cat on every batch.
+    ground_truth = torch.cat(ground_truth)
+    inferences = torch.cat(inferences)
+
+    # Calculate the slice-level metrics and keep them for saving at the end.
+    metrics = {f'metrics{t}': calc_metrics(ground_truth=ground_truth,
+                                           inferences=inferences,
+                                           threshold=t)
+               for t in (0.5, 0.7, 0.9)}
+    metrics['roc_auc'] = roc_auc(ground_truth=ground_truth,
+                                 inferences=inferences,
+                                 experiment_dir=experiment_dir)
+    metrics['pr_auc'] = pr_auc(ground_truth=ground_truth,
+                               inferences=inferences,
+                               experiment_dir=experiment_dir)
+
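+    # Patient-level evaluation: a patient is predicted IH-positive when
+    # the number of slices with IH probability > 0.8 reaches a threshold,
+    # given either as an absolute slice count (1-5) or as a percentage of
+    # the patient's slices (5%, 7.5%, 10%).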
+    patients_dataset = conf['patients_dataset']
+    test_patients = patients_dataset.test_patients
+    patients_bar = tqdm(test_patients.items(),
+                        desc='Patient', unit='patients', leave=True)
+
+    inferences = {1: [], 2: [], 3: [], 4: [], 5: [],
+                  '5%': [], '7.5%': [], '10%': []}
+    ground_truth = []
+    for patient, patient_data in patients_bar:
+        patient_IH = patient_data['IH']  # Whether the patient has IH
+        slices = patient_data['slices_IH'] + patient_data['slices_noIH']
+        slices_with_IH = 0
+        samples = []
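+        # Load every slice belonging to this patient; the second value
+        # returned by getSlice (presumably the slice label) is unused here.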
+        for slice_id in slices:
+            sample, _ = patients_dataset.getSlice(slice_id)
+            samples.append(sample)
+
+        # Run the patient's slices through the model in mini-batches and
+        # count how many slices are classified as IH.
+        for i in range(0, len(samples), batch_size):
+            batch = torch.stack(samples[i:i + batch_size], dim=0)
+            batch = batch.to(device)
+            with torch.no_grad():
+                outputs = model(batch)
+            IH_probs = F.softmax(outputs, dim=1)[:, 1]
+            # .item() keeps the counter a plain Python int rather than a
+            # (possibly GPU) tensor, which np.array below cannot handle.
+            slices_with_IH += (IH_probs > 0.8).sum().item()
+        for num_IH_threshold in [1, 2, 3, 4, 5]:
+            net_IH_prediction = slices_with_IH >= num_IH_threshold
+            inferences[num_IH_threshold].append(net_IH_prediction)
+        for percentage, key in [(0.05, '5%'), (0.075, '7.5%'), (0.10, '10%')]:
+            num_IH_threshold = max(1, round(percentage * len(slices)))
+            net_IH_prediction = slices_with_IH >= num_IH_threshold
+            inferences[key].append(net_IH_prediction)
+
+        ground_truth.append(patient_IH)
+
+    # The 'threshold' entry reported by calc_metrics is a probability
+    # cutoff (compare the slice-level calls above); it does not apply to
+    # these already-binary patient predictions, so it is dropped.
+    ground_truth = np.array(ground_truth).astype(float)
+    for key in [1, 2, 3, 4, 5, '5%', '7.5%', '10%']:
+        patient_metrics = calc_metrics(
+            ground_truth=ground_truth,
+            inferences=np.array(inferences[key]).astype(float))
+        del patient_metrics['threshold']
+        metrics[f'patients_metrics (>= {key} IH slice)'] = patient_metrics
+
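+    # Write all slice- and patient-level metrics to metrics.yaml in the
+    # experiment directory.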
+    with open(os.path.join(experiment_dir, 'metrics.yaml'), 'w') as fp:
+        yaml.dump(metrics, fp)
\ No newline at end of file