notebooks/mask_entropy_dice_metric_relations.py

"""Check prediction entropy and dice score correlation"""
# %%
import os
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import torch
# enable lib loading even if not installed as a pip package or in PYTHONPATH
# also convenient for relative paths in example config files
from pathlib import Path
os.chdir(Path(__file__).resolve().parent.parent)
from adpkd_segmentation.config.config_utils import get_object_instance # noqa
from adpkd_segmentation.evaluate import validate # noqa
from adpkd_segmentation.utils.train_utils import load_model_data # noqa
sns.set()
# %%
CONFIG = "experiments/september02/random_split_new_data_less_albu_10_more/val/val.yaml" # noqa
with open(CONFIG, "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
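# %%
# Optional sanity check (a small sketch, not part of the original workflow):
# fail early with a clear message if any expected top-level key is absent,
# rather than hitting a bare KeyError below.
expected_keys = {
    "_MODEL_CONFIG",
    "_VAL_DATALOADER_CONFIG",
    "_LOSSES_METRICS_CONFIG",
    "_MODEL_CHECKPOINT",
}
missing = expected_keys - config.keys()
assert not missing, f"config is missing keys: {missing}"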
# %%
model_config = config["_MODEL_CONFIG"]
dataloader_config = config["_VAL_DATALOADER_CONFIG"]
losses_config = config["_LOSSES_METRICS_CONFIG"]
saved_checkpoint = config["_MODEL_CHECKPOINT"]
# override: batch size 1, so each entry in the per-batch output lists
# below corresponds to a single example
dataloader_config["batchsize"] = 1
# %%
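# instantiate the model, dataloader, and loss/metric objects from their
# config sections, then load the saved checkpoint weights into the model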
model = get_object_instance(model_config)()
load_model_data(saved_checkpoint, model, True)
dataloader = get_object_instance(dataloader_config)()
loss_metric = get_object_instance(losses_config)()
# %%
# run inference on the GPU, falling back to CPU if CUDA is unavailable
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
# %%
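# run validation once over the dataloader; with output_losses_list=True the
# per-batch values of every loss/metric are returned alongside the averages
# (per-example here, since the batch size was overridden to 1)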
averaged, all_losses_and_metrics = validate(
    dataloader, model, loss_metric, device, output_losses_list=True)
# %%
dice_scores = all_losses_and_metrics["dice_metric"]
entropy = all_losses_and_metrics["prediction_entropy"]
sns.scatterplot(x=entropy, y=dice_scores)
plt.xlabel("Prediction entropy")
plt.ylabel("Dice score")
# %%
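# Follow-up sketch (not in the original): quantify the trend seen in the
# scatterplot with a rank correlation. Assumes scipy is available and that
# the metric lists hold plain floats (convert tensors first if needed).
from scipy.stats import spearmanr

rho, pval = spearmanr(entropy, dice_scores)
print(f"Spearman correlation: {rho:.3f} (p = {pval:.3g})")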