Switch to unified view

a b/adpkd_segmentation/evaluate.py
1
"""
2
Model evaluation script
3
4
python -m adpkd_segmentation.evaluate --config path_to_config_yaml --makelinks
5
6
If using a specific GPU (e.g. device 2):
7
CUDA_VISIBLE_DEVICES=2 python -m evaluate --config path_to_config_yaml
8
9
The makelinks flag is needed only once to create symbolic links to the data.
10
"""
11
12
# %%
13
import argparse
14
import json
15
import os
16
from collections import defaultdict
17
18
import torch
19
import yaml
20
from matplotlib import pyplot as plt
21
22
from adpkd_segmentation.config.config_utils import get_object_instance
23
from adpkd_segmentation.data.link_data import makelinks
24
from adpkd_segmentation.data.data_utils import masks_to_colorimg
25
from adpkd_segmentation.data.data_utils import tensor_dict_to_device
26
from adpkd_segmentation.utils.train_utils import load_model_data
27
28
29
# %%
30
def validate(
31
    dataloader,
32
    model,
33
    loss_metric,
34
    device,
35
    plotting_func=None,
36
    plotting_dict=None,
37
    writer=None,
38
    global_step=None,
39
    val_metric_to_check=None,
40
    output_losses_list=False,
41
):
42
    all_losses_and_metrics = defaultdict(list)
43
    num_examples = 0
44
    output_example_idx = (
45
        hasattr(dataloader.dataset, "output_idx")
46
        and dataloader.dataset.output_idx
47
    )
48
49
    for batch_idx, output in enumerate(dataloader):
50
        if output_example_idx:
51
            x_batch, y_batch, index = output
52
            extra_dict = dataloader.dataset.get_extra_dict(index)
53
            extra_dict = tensor_dict_to_device(extra_dict, device)
54
        else:
55
            x_batch, y_batch = output
56
            extra_dict = None
57
        x_batch = x_batch.to(device)
58
        y_batch = y_batch.to(device)
59
        batch_size = y_batch.size(0)
60
        num_examples += batch_size
61
        with torch.no_grad():
62
            y_batch_hat = model(x_batch)
63
            losses_and_metrics = loss_metric(y_batch_hat, y_batch, extra_dict)
64
65
            for key, value in losses_and_metrics.items():
66
                all_losses_and_metrics[key].append(value.item() * batch_size)
67
68
            if plotting_dict is not None and batch_idx in plotting_dict:
69
                # TODO: add support for softmax processing
70
                prediction = torch.sigmoid(y_batch_hat)
71
                image_idx = plotting_dict[batch_idx]
72
                global_im_index = batch_idx * batch_size + image_idx
73
                extra_dict = dataloader.dataset.get_extra_dict(
74
                    [global_im_index]
75
                )
76
                extra_dict = tensor_dict_to_device(extra_dict, device)
77
                plotting_func(
78
                    writer=writer,
79
                    batch=x_batch,
80
                    prediction=prediction,
81
                    target=y_batch,
82
                    global_step=global_step,
83
                    idx=image_idx,
84
                    title="val_batch_{}_image_{}".format(batch_idx, image_idx),
85
                )
86
                # check DSC metric for this image
87
                # `loss_metric` expects raw model outputs without the sigmoid
88
                im_pred = y_batch_hat[image_idx].unsqueeze(0)
89
                im_target_mask = y_batch[image_idx].unsqueeze(0)
90
                im_losses = loss_metric(im_pred, im_target_mask, extra_dict)
91
                writer.add_scalar(
92
                    "val_batch_{}_image_{}_{}".format(
93
                        batch_idx, image_idx, val_metric_to_check
94
                    ),
95
                    im_losses[val_metric_to_check],
96
                    global_step,
97
                )
98
99
    averaged = {}
100
    for key, value in all_losses_and_metrics.items():
101
        averaged[key] = sum(all_losses_and_metrics[key]) / num_examples
102
103
    if output_losses_list:
104
        return averaged, all_losses_and_metrics
105
    return averaged
106
107
108
# %%
109
def evaluate(config):
110
    model_config = config["_MODEL_CONFIG"]
111
    loader_to_eval = config["_LOADER_TO_EVAL"]
112
    dataloader_config = config[loader_to_eval]
113
    loss_metric_config = config["_LOSSES_METRICS_CONFIG"]
114
    results_path = config["_RESULTS_PATH"]
115
    saved_checkpoint = config["_MODEL_CHECKPOINT"]
116
    checkpoint_format = config["_NEW_CKP_FORMAT"]
117
118
    model = get_object_instance(model_config)()
119
    if saved_checkpoint is not None:
120
        load_model_data(saved_checkpoint, model, new_format=checkpoint_format)
121
122
    dataloader = get_object_instance(dataloader_config)()
123
    loss_metric = get_object_instance(loss_metric_config)()
124
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
125
126
    model = model.to(device)
127
    model.eval()
128
    all_losses_and_metrics = validate(dataloader, model, loss_metric, device)
129
130
    os.makedirs(results_path)
131
    with open("{}/val_results.json".format(results_path), "w") as fp:
132
        print(all_losses_and_metrics)
133
        json.dump(all_losses_and_metrics, fp, indent=4)
134
135
    # plotting check
136
    output_example_idx = (
137
        hasattr(dataloader.dataset, "output_idx")
138
        and dataloader.dataset.output_idx
139
    )
140
    data_iter = iter(dataloader)
141
    if output_example_idx:
142
        inputs, labels, _ = next(data_iter)
143
    else:
144
        inputs, labels = next(data_iter)
145
146
    inputs = inputs.to(device)
147
    preds = model(inputs)
148
    inputs = inputs.cpu()
149
    preds = preds.cpu()
150
151
    plot_figure_from_batch(inputs, preds)
152
153
154
# %%
155
def plot_figure_from_batch(inputs, preds, target=None, idx=0):
156
157
    f, axarr = plt.subplots(1, 2)
158
    axarr[0].imshow(inputs[idx][1], cmap="gray")
159
    axarr[1].imshow(inputs[idx][1], cmap="gray")  # background for mask
160
    axarr[1].imshow(masks_to_colorimg(preds[idx]), alpha=0.5)
161
162
    return f
163
164
165
# %%
166
def quick_check(config_path, run_makelinks=False):
167
    if run_makelinks:
168
        makelinks()
169
    with open(config_path, "r") as f:
170
        config = yaml.load(f, Loader=yaml.FullLoader)
171
    evaluate(config)
172
173
174
# %%
175
if __name__ == "__main__":
176
    parser = argparse.ArgumentParser()
177
    parser.add_argument(
178
        "--config", help="YAML config path", type=str, required=True
179
    )
180
    parser.add_argument(
181
        "--makelinks", help="Make data links", action="store_true"
182
    )
183
184
    args = parser.parse_args()
185
    with open(args.config, "r") as f:
186
        config = yaml.load(f, Loader=yaml.FullLoader)
187
188
    if args.makelinks:
189
        makelinks()
190
191
    evaluate(config)