Diff of /src/evaluation.py [000000] .. [f45789]

Switch to unified view

a b/src/evaluation.py
1
import sys
2
sys.path.append('.')
3
from tqdm import tqdm
4
import torch
5
from torch.nn import functional as F
6
from .metrics import roc_auc, pr_auc, calc_metrics
7
import os
8
import yaml
9
import numpy as np
10
11
12
def evaluate(conf):
13
    device = conf['device']
14
    dataloader = conf['dataloaders']['test']
15
    experiment_dir = conf['experiment_dir']
16
    classes = conf['data']['classes']
17
    batch_size = conf['data']['batch_size']
18
19
    model = conf['model']
20
    model.load_state_dict(conf['best_weights'])
21
    model = model.to(device)
22
    model.eval()
23
24
25
    ground_truth = None
26
    inferences = None
27
28
    batch_bar = tqdm(dataloader, desc='Batch', unit='batches', leave=True)
29
    for inputs, labels in batch_bar:
30
        inputs = inputs.to(device)
31
        with torch.set_grad_enabled(False):
32
            outputs = model(inputs)
33
34
        probs = F.softmax(outputs, dim=1)[:, 1]
35
        probs = probs.cpu()
36
        if ground_truth is None and inferences is None:
37
            ground_truth = labels
38
            inferences = probs
39
        else:
40
            ground_truth = torch.cat((ground_truth, labels))
41
            inferences = torch.cat((inferences, probs))
42
43
        # Calculate save metrics
44
45
    metrics = {'metrics0.5': calc_metrics(ground_truth=ground_truth,
46
                                          inferences=inferences,
47
                                          threshold=0.5),
48
               'metrics0.7': calc_metrics(ground_truth=ground_truth,
49
                                          inferences=inferences,
50
                                          threshold=0.7),
51
               'metrics0.9': calc_metrics(ground_truth=ground_truth,
52
                                          inferences=inferences,
53
                                          threshold=0.9),
54
               'roc_auc': roc_auc(ground_truth=ground_truth,
55
                                  inferences=inferences,
56
                                  experiment_dir=experiment_dir),
57
               'pr_auc': pr_auc(ground_truth=ground_truth,
58
                                inferences=inferences,
59
                                experiment_dir=experiment_dir),
60
                }
61
62
    patients_dataset = conf['patients_dataset']
63
    test_patients = patients_dataset.test_patients
64
    patients_bar = tqdm(test_patients.items(),
65
                        desc='Patient', unit='patients', leave=True)
66
67
    inferences = {1: [], 2:[], 3: [], 4: [], 5: [],
68
                 '5%':[], '7.5%': [], '10%': []}
69
    ground_truth = []
70
    for patient, patient_data in patients_bar:
71
        patient_IH = patient_data['IH'] # If the patient has IH or not
72
        slices = patient_data['slices_IH'] + patient_data['slices_noIH']
73
        slices_with_IH = 0
74
        samples = []
75
        for slice_id in slices:
76
            sample, _ = patients_dataset.getSlice(slice_id)
77
            samples.append(sample)
78
79
        for i in range(0, len(samples), batch_size):
80
            batch = torch.stack(samples[i:i+batch_size], dim=0)
81
            batch = batch.to(device)
82
            with torch.set_grad_enabled(False):
83
                outputs = model(batch)
84
            IH_probs = F.softmax(outputs, dim=1)[:, 1]
85
            slices_with_IH += (IH_probs > 0.8).sum()
86
        for num_IH_threshold in [1,2,3,4,5]:
87
            net_IH_prediction = slices_with_IH >= num_IH_threshold
88
            inferences[num_IH_threshold].append(net_IH_prediction)
89
        for percentage, key in [(0.05, '5%'), (0.075, '7.5%'), (0.10, '10%')]:
90
            num_IH_threshold = max(1, round(percentage * len(slices)))
91
            net_IH_prediction = slices_with_IH >= num_IH_threshold
92
            inferences[key].append(net_IH_prediction)
93
94
        ground_truth.append(patient_IH)
95
96
97
    ground_truth = np.array(ground_truth).astype(float)
98
    inferences1 = np.array(inferences[1]).astype(float)
99
    inferences2 = np.array(inferences[2]).astype(float)
100
    inferences3 = np.array(inferences[3]).astype(float)
101
    inferences4 = np.array(inferences[4]).astype(float)
102
    inferences5 = np.array(inferences[5]).astype(float)
103
    inferencesperc1 = np.array(inferences['5%']).astype(float)
104
    inferencesperc2 = np.array(inferences['7.5%']).astype(float)
105
    inferencesperc3 = np.array(inferences['10%']).astype(float)
106
    metrics['patients_metrics (>= 1 IH slice)'] = calc_metrics(
107
                                                    ground_truth=ground_truth,
108
                                                    inferences=inferences1)
109
    metrics['patients_metrics (>= 2 IH slice)'] = calc_metrics(
110
                                                    ground_truth=ground_truth,
111
                                                    inferences=inferences2)
112
    metrics['patients_metrics (>= 3 IH slice)'] = calc_metrics(
113
                                                    ground_truth=ground_truth,
114
                                                    inferences=inferences3)
115
    metrics['patients_metrics (>= 4 IH slice)'] = calc_metrics(
116
                                                    ground_truth=ground_truth,
117
                                                    inferences=inferences4)
118
    metrics['patients_metrics (>= 5 IH slice)'] = calc_metrics(
119
                                                    ground_truth=ground_truth,
120
                                                    inferences=inferences5)
121
    metrics['patients_metrics (>= 5% IH slice)'] = calc_metrics(
122
                                                    ground_truth=ground_truth,
123
                                                    inferences=inferencesperc1)
124
    metrics['patients_metrics (>= 7.5% IH slice)'] = calc_metrics(
125
                                                    ground_truth=ground_truth,
126
                                                    inferences=inferencesperc2)
127
    metrics['patients_metrics (>= 10% IH slice)'] = calc_metrics(
128
                                                    ground_truth=ground_truth,
129
                                                    inferences=inferencesperc3)
130
131
    del metrics['patients_metrics (>= 1 IH slice)']['threshold']
132
    del metrics['patients_metrics (>= 2 IH slice)']['threshold']
133
    del metrics['patients_metrics (>= 3 IH slice)']['threshold']
134
    del metrics['patients_metrics (>= 4 IH slice)']['threshold']
135
    del metrics['patients_metrics (>= 5 IH slice)']['threshold']
136
    del metrics['patients_metrics (>= 5% IH slice)']['threshold']
137
    del metrics['patients_metrics (>= 7.5% IH slice)']['threshold']
138
    del metrics['patients_metrics (>= 10% IH slice)']['threshold']
139
140
    with open(os.path.join(experiment_dir, 'metrics.yaml'), 'w') as fp:
141
        yaml.dump(metrics, fp)