[637b40]: / adpkd_segmentation / evaluate.py

Download this file

192 lines (158 with data), 6.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
Model evaluation script
python -m adpkd_segmentation.evaluate --config path_to_config_yaml --makelinks
If using a specific GPU (e.g. device 2):
CUDA_VISIBLE_DEVICES=2 python -m evaluate --config path_to_config_yaml
The makelinks flag is needed only once to create symbolic links to the data.
"""
# %%
import argparse
import json
import os
from collections import defaultdict
import torch
import yaml
from matplotlib import pyplot as plt
from adpkd_segmentation.config.config_utils import get_object_instance
from adpkd_segmentation.data.link_data import makelinks
from adpkd_segmentation.data.data_utils import masks_to_colorimg
from adpkd_segmentation.data.data_utils import tensor_dict_to_device
from adpkd_segmentation.utils.train_utils import load_model_data
# %%
def validate(
    dataloader,
    model,
    loss_metric,
    device,
    plotting_func=None,
    plotting_dict=None,
    writer=None,
    global_step=None,
    val_metric_to_check=None,
    output_losses_list=False,
):
    """Run the model over `dataloader` and average losses/metrics per example.

    Parameters
    ----------
    dataloader : iterable yielding (x, y) or (x, y, index) batches; must
        expose a `.dataset` attribute (checked for `output_idx`).
    model : callable mapping an input batch to raw (pre-sigmoid) outputs.
    loss_metric : callable `(y_hat, y, extra_dict) -> dict[str, Tensor]`
        of scalar losses/metrics.
    device : torch device (or device string) to move batches to.
    plotting_func, plotting_dict, writer, global_step, val_metric_to_check :
        optional TensorBoard plotting hooks. `plotting_dict` maps
        `batch_idx -> image_idx`; when a batch is selected, `plotting_func`
        and `writer` must be provided (they are called unconditionally).
    output_losses_list : when True, also return the per-batch
        (batch-size-weighted) loss lists.

    Returns
    -------
    dict of per-example averages, or `(averages, per_batch_lists)` when
    `output_losses_list` is True.
    """
    all_losses_and_metrics = defaultdict(list)
    num_examples = 0
    # Datasets that set `output_idx` also yield example indices, which are
    # needed to look up per-example extra inputs for the loss.
    output_example_idx = (
        hasattr(dataloader.dataset, "output_idx")
        and dataloader.dataset.output_idx
    )
    for batch_idx, output in enumerate(dataloader):
        if output_example_idx:
            x_batch, y_batch, index = output
            extra_dict = dataloader.dataset.get_extra_dict(index)
            extra_dict = tensor_dict_to_device(extra_dict, device)
        else:
            x_batch, y_batch = output
            extra_dict = None
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        batch_size = y_batch.size(0)
        num_examples += batch_size
        with torch.no_grad():
            y_batch_hat = model(x_batch)
            losses_and_metrics = loss_metric(y_batch_hat, y_batch, extra_dict)
            # Weight each batch value by its size so the final division by
            # `num_examples` yields a true per-example average even when the
            # last batch is smaller.
            for key, value in losses_and_metrics.items():
                all_losses_and_metrics[key].append(value.item() * batch_size)
            if plotting_dict is not None and batch_idx in plotting_dict:
                # TODO: add support for softmax processing
                prediction = torch.sigmoid(y_batch_hat)
                image_idx = plotting_dict[batch_idx]
                # NOTE(review): assumes a constant batch size across batches;
                # if the final batch is smaller this index is off — confirm
                # the loader uses drop_last or a divisible dataset size.
                global_im_index = batch_idx * batch_size + image_idx
                extra_dict = dataloader.dataset.get_extra_dict(
                    [global_im_index]
                )
                extra_dict = tensor_dict_to_device(extra_dict, device)
                plotting_func(
                    writer=writer,
                    batch=x_batch,
                    prediction=prediction,
                    target=y_batch,
                    global_step=global_step,
                    idx=image_idx,
                    title="val_batch_{}_image_{}".format(batch_idx, image_idx),
                )
                # check DSC metric for this image
                # `loss_metric` expects raw model outputs without the sigmoid
                im_pred = y_batch_hat[image_idx].unsqueeze(0)
                im_target_mask = y_batch[image_idx].unsqueeze(0)
                im_losses = loss_metric(im_pred, im_target_mask, extra_dict)
                writer.add_scalar(
                    "val_batch_{}_image_{}_{}".format(
                        batch_idx, image_idx, val_metric_to_check
                    ),
                    im_losses[val_metric_to_check],
                    global_step,
                )
    # Per-example average of each tracked quantity.
    averaged = {
        key: sum(values) / num_examples
        for key, values in all_losses_and_metrics.items()
    }
    if output_losses_list:
        return averaged, all_losses_and_metrics
    return averaged
# %%
def evaluate(config):
    """Build model/dataloader/metrics from `config`, run validation, and
    dump the averaged results to `<results_path>/val_results.json`.

    Parameters
    ----------
    config : dict parsed from the experiment YAML; required keys are
        "_MODEL_CONFIG", "_LOADER_TO_EVAL" (names the dataloader config key),
        "_LOSSES_METRICS_CONFIG", "_RESULTS_PATH", "_MODEL_CHECKPOINT"
        (may be None), and "_NEW_CKP_FORMAT".

    Side effects: creates the results directory, writes the JSON file,
    prints the metrics, and renders a sanity-check figure for the first
    batch via `plot_figure_from_batch`.
    """
    model_config = config["_MODEL_CONFIG"]
    loader_to_eval = config["_LOADER_TO_EVAL"]
    dataloader_config = config[loader_to_eval]
    loss_metric_config = config["_LOSSES_METRICS_CONFIG"]
    results_path = config["_RESULTS_PATH"]
    saved_checkpoint = config["_MODEL_CHECKPOINT"]
    checkpoint_format = config["_NEW_CKP_FORMAT"]

    model = get_object_instance(model_config)()
    if saved_checkpoint is not None:
        load_model_data(saved_checkpoint, model, new_format=checkpoint_format)

    dataloader = get_object_instance(dataloader_config)()
    loss_metric = get_object_instance(loss_metric_config)()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    all_losses_and_metrics = validate(dataloader, model, loss_metric, device)

    # exist_ok avoids crashing when re-running into an existing results dir
    os.makedirs(results_path, exist_ok=True)
    with open("{}/val_results.json".format(results_path), "w") as fp:
        print(all_losses_and_metrics)
        json.dump(all_losses_and_metrics, fp, indent=4)

    # plotting check: run the first batch through the model and render it
    output_example_idx = (
        hasattr(dataloader.dataset, "output_idx")
        and dataloader.dataset.output_idx
    )
    data_iter = iter(dataloader)
    if output_example_idx:
        inputs, labels, _ = next(data_iter)
    else:
        inputs, labels = next(data_iter)
    inputs = inputs.to(device)
    # inference only — no need to track gradients for the sanity plot
    with torch.no_grad():
        preds = model(inputs)
    inputs = inputs.cpu()
    preds = preds.cpu()
    plot_figure_from_batch(inputs, preds)
# %%
def plot_figure_from_batch(inputs, preds, target=None, idx=0):
    """Plot an input slice next to its predicted (and optional target) mask.

    Parameters
    ----------
    inputs : batch of input images; `inputs[idx][1]` is shown in grayscale.
        NOTE(review): assumes channel 1 is the display channel — confirm
        against the dataloader's channel layout.
    preds : batch of predicted masks, converted via `masks_to_colorimg`.
    target : optional batch of ground-truth masks; when given, a third
        panel overlays the target mask (previously this argument was
        accepted but ignored).
    idx : index of the example within the batch to plot.

    Returns
    -------
    The matplotlib figure.
    """
    ncols = 2 if target is None else 3
    f, axarr = plt.subplots(1, ncols)
    axarr[0].imshow(inputs[idx][1], cmap="gray")
    axarr[1].imshow(inputs[idx][1], cmap="gray")  # background for mask
    axarr[1].imshow(masks_to_colorimg(preds[idx]), alpha=0.5)
    if target is not None:
        axarr[2].imshow(inputs[idx][1], cmap="gray")  # background for mask
        axarr[2].imshow(masks_to_colorimg(target[idx]), alpha=0.5)
    return f
# %%
def quick_check(config_path, run_makelinks=False):
    """Convenience entry point: load the YAML at `config_path` and evaluate.

    Set `run_makelinks=True` to (re)create the data symlinks first.
    """
    if run_makelinks:
        makelinks()
    with open(config_path, "r") as config_file:
        loaded_config = yaml.load(config_file, Loader=yaml.FullLoader)
    evaluate(loaded_config)
# %%
if __name__ == "__main__":
    # CLI entry point: parse arguments, optionally create the data
    # symlinks, then run evaluation with the loaded YAML config.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config", help="YAML config path", type=str, required=True
    )
    cli.add_argument(
        "--makelinks", help="Make data links", action="store_true"
    )
    parsed = cli.parse_args()

    with open(parsed.config, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    if parsed.makelinks:
        makelinks()

    evaluate(config)