adpkd_segmentation/evaluate_patients.py

"""
2
Model evaluation script for TKV
3
4
python -m adpkd_segmentation.evaluate_patients
5
--config path_to_config_yaml --makelinks --out_path output_csv_path
6
7
If using a specific GPU, e.g. device 2, prepend the command with CUDA_VISIBLE_DEVICES=2 # noqa
8
9
The makelinks flag is needed only once to create symbolic links to the data.
10
"""

# %%
from collections import OrderedDict, defaultdict
import argparse

import yaml
import pandas as pd

import torch

from adpkd_segmentation.config.config_utils import get_object_instance
from adpkd_segmentation.data.link_data import makelinks
from adpkd_segmentation.utils.train_utils import load_model_data
from adpkd_segmentation.utils.losses import SigmoidBinarize


# %%
def calculate_dcm_voxel_volumes(
    dataloader, model, device, binarize_func,
):
    num_examples = 0
    dataset = dataloader.dataset
    updated_dcm2attribs = {}

    # some dataset variants also yield the example index as a third output
    output_example_idx = (
        hasattr(dataloader.dataset, "output_idx")
        and dataloader.dataset.output_idx
    )

    for batch_idx, output in enumerate(dataloader):
        if output_example_idx:
            x_batch, y_batch, _ = output
        else:
            x_batch, y_batch = output

        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        batch_size = y_batch.size(0)
        num_examples += batch_size
        with torch.no_grad():
            y_batch_hat = model(x_batch)
            y_batch_hat_binary = binarize_func(y_batch_hat)
            # map the running example count back to dataset indices;
            # this assumes the dataloader iterates without shuffling
            start_idx = num_examples - batch_size
            end_idx = num_examples

            for inbatch_idx, dataset_idx in enumerate(
                range(start_idx, end_idx)
            ):
                # calculate TKV and TKV inputs for each dcm
                # TODO:
                # support 3-channel setups where ones could mean background
                # needs mask standardization to single channel
                _, dcm_path, attribs = dataset.get_verbose(dataset_idx)
                attribs["pred_kidney_pixels"] = torch.sum(
                    y_batch_hat_binary[inbatch_idx] > 0
                ).item()
                attribs["ground_kidney_pixels"] = torch.sum(
                    y_batch[inbatch_idx] > 0
                ).item()

                # TODO: clean up method of accessing the Resize transform
                attribs["transform_resize_dim"] = (
                    dataloader.dataset.augmentation[0].height,
                    dataloader.dataset.augmentation[0].width,
                )

                # the scale factor accounts for the difference between
                # the original image/mask size and the size after
                # mask & prediction resizing
                scale_factor = (attribs["dim"][0] ** 2) / (
                    attribs["transform_resize_dim"][0] ** 2
                )
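                # illustrative arithmetic (sizes made up): a 512x512
                # source image resized to 256x256 for inference gives
                # scale_factor = 512**2 / 256**2 = 4, i.e. each resized
                # pixel stands in for 4 original pixels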
                attribs["Vol_GT"] = (
                    scale_factor
                    * attribs["vox_vol"]
                    * attribs["ground_kidney_pixels"]
                )
                attribs["Vol_Pred"] = (
                    scale_factor
                    * attribs["vox_vol"]
                    * attribs["pred_kidney_pixels"]
                )

                updated_dcm2attribs[dcm_path] = attribs

    return updated_dcm2attribs
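
# Shape of the mapping returned above (values illustrative; each attribs
# dict also keeps whatever get_verbose() provided, e.g. "dim", "vox_vol",
# "patient", "MR", "seq"):
#
#   {
#       "path/to/slice.dcm": {
#           "pred_kidney_pixels": 1234,
#           "ground_kidney_pixels": 1250,
#           "transform_resize_dim": (256, 256),
#           "Vol_GT": 1250 * vox_vol * scale_factor,
#           "Vol_Pred": 1234 * vox_vol * scale_factor,
#           ...
#       },
#       ...
#   }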


# %%
def visualize_performance(
    dataloader, model, device, binarize_func,
):
    # TODO: work in progress; the per-example visualization loop below
    # is still commented out
    dataset = dataloader.dataset
    num_examples = 0
    output_example_idx = (
        hasattr(dataloader.dataset, "output_idx")
        and dataloader.dataset.output_idx
    )

    for batch_idx, output in enumerate(dataloader):
        if output_example_idx:
            x_batch, y_batch, _ = output
        else:
            x_batch, y_batch = output

        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        batch_size = y_batch.size(0)
        num_examples += batch_size
        with torch.no_grad():
            _, dcm_path, attribs = dataset.get_verbose(batch_size * batch_idx)

            y_batch_hat = model(x_batch)
            y_batch_hat_binary = binarize_func(y_batch_hat)
            start_idx = batch_size * batch_idx
            end_idx = batch_size * (1 + batch_idx)

            # for inbatch_idx, dataset_idx in enumerate(
            #     range(start_idx, end_idx)
            # ):
            #     _, dcm_path, attribs = dataset.get_verbose(dataset_idx)

            #     updated_dcm2attribs[dcm_path] = attribs


# %%
def evaluate(config):
    model_config = config["_MODEL_CONFIG"]
    loader_to_eval = config["_LOADER_TO_EVAL"]
    dataloader_config = config[loader_to_eval]
    saved_checkpoint = config["_MODEL_CHECKPOINT"]
    checkpoint_format = config["_NEW_CKP_FORMAT"]

    model = get_object_instance(model_config)()
    if saved_checkpoint is not None:
        load_model_data(saved_checkpoint, model, new_format=checkpoint_format)

    dataloader = get_object_instance(dataloader_config)()

    # TODO: support other metrics as needed
    binarize_func = SigmoidBinarize(thresholds=[0.5])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    updated_dcm2attribs = calculate_dcm_voxel_volumes(
        dataloader, model, device, binarize_func
    )

    return updated_dcm2attribs
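
# Sketch of the config keys evaluate() expects (values are placeholders;
# "_VAL_DATALOADER_CONFIG" is a hypothetical loader key name, inferred
# from the split parsing in calculate_TKVs below):
#
#   _MODEL_CONFIG: {...}            # object spec for get_object_instance
#   _LOADER_TO_EVAL: _VAL_DATALOADER_CONFIG
#   _VAL_DATALOADER_CONFIG: {...}   # object spec for the dataloader
#   _MODEL_CHECKPOINT: path/to/checkpoint.pth   # or null to skip loading
#   _NEW_CKP_FORMAT: True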


# %%
def calculate_TKVs(config_path, run_makelinks=False, output=None):
    if run_makelinks:
        makelinks()
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # val or test
    split = config["_LOADER_TO_EVAL"].split("_")[1].lower()

    dcm2attrib = evaluate(config)

    patient_MR_TKV = defaultdict(float)
    TKV_data = OrderedDict()

    # sum per-slice volumes into per-(patient, MR series) totals
    for value in dcm2attrib.values():
        patient_MR = value["patient"] + value["MR"]
        patient_MR_TKV[(patient_MR, "GT")] += value["Vol_GT"]
        patient_MR_TKV[(patient_MR, "Pred")] += value["Vol_Pred"]

    # build one summary row per patient-MR pair
    for value in dcm2attrib.values():
        patient_MR = value["patient"] + value["MR"]

        if patient_MR not in TKV_data:
            summary = {
                "TKV_GT": patient_MR_TKV[(patient_MR, "GT")],
                "TKV_Pred": patient_MR_TKV[(patient_MR, "Pred")],
                "sequence": value["seq"],
                "split": split,
            }

            TKV_data[patient_MR] = summary

    df = pd.DataFrame(TKV_data).transpose()

    if output is not None:
        df.to_csv(output)

    return TKV_data
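
# Layout of the resulting CSV (index is patient + MR series id; values
# are illustrative):
#
#                TKV_GT     TKV_Pred   sequence  split
#   PATIENT1MR1  512345.6   498765.4   T2        val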


# %%
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", help="YAML config path", type=str, required=True
    )
    parser.add_argument(
        "--makelinks", help="Make data links", action="store_true"
    )
    parser.add_argument("--out_path", help="Path to output csv", required=True)

    args = parser.parse_args()
    calculate_TKVs(args.config, args.makelinks, args.out_path)