|
a |
|
b/rocaseg/evaluate.py |
|
|
1 |
import os |
|
|
2 |
import logging |
|
|
3 |
from glob import glob |
|
|
4 |
|
|
|
5 |
import numpy as np |
|
|
6 |
from skimage.color import label2rgb |
|
|
7 |
from skimage import img_as_ubyte |
|
|
8 |
from tqdm import tqdm |
|
|
9 |
import click |
|
|
10 |
|
|
|
11 |
import cv2 |
|
|
12 |
import tifffile |
|
|
13 |
import torch |
|
|
14 |
import torch.nn as nn |
|
|
15 |
from torch.utils.data.dataloader import DataLoader |
|
|
16 |
|
|
|
17 |
from rocaseg.datasets import sources_from_path |
|
|
18 |
from rocaseg.components import CheckpointHandler |
|
|
19 |
from rocaseg.components.formats import numpy_to_nifti, png_to_numpy |
|
|
20 |
from rocaseg.models import dict_models |
|
|
21 |
from rocaseg.preproc import * |
|
|
22 |
from rocaseg.repro import set_ultimate_seed |
|
|
23 |
|
|
|
24 |
|
|
|
25 |
# Workaround for the PyTorch multiprocessing failure
# "RuntimeError: received 0 items of ancdata": share tensors through the
# file system.
torch.multiprocessing.set_sharing_strategy('file_system')

# Turn off OpenCL and OpenCV-managed threading
# (presumably to avoid clashes with DataLoader workers - TODO confirm).
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

# Module logger; `main` attaches a file handler to it later.
logging.basicConfig()
logger = logging.getLogger('eval')
logger.setLevel(logging.INFO)

set_ultimate_seed()

# Device name used for the model and the input batches.
maybe_gpu = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def predict_folds(config, loader, fold_idcs):
    """Predict with the model of each fold and save per-sample softmax maps.

    For every fold index, the checkpoint returned by
    ``CheckpointHandler.get_last_ckpt()`` for
    ``<path_weights>/segm/fold_<idx>`` is handed to the model constructor and
    the model is run over ``loader``. Each sample's softmax output is scaled
    by 255, cast to uint8 and written to
    ``<path_predicts>/<patient>/<release>/<sequence>/mask_folds/``
    as ``<image name>_fold_<idx>.tiff``.

    Args:
        config: Mapping providing 'path_weights', 'path_predicts',
            'model_segm', 'input_channels', 'output_channels',
            'center_depth', 'pretrained' and 'restore_weights'.
        loader: Sized iterable of batch dicts. Each batch must provide 'xs',
            'image', 'patient', 'release', 'sequence' and 'path_rel_image',
            and every value must be indexable by sample position.
        fold_idcs: Iterable of fold indices to process.

    Raises:
        ValueError: If ``config['model_segm']`` is neither 'unet_lext' nor
            'unet_lext_aux'.
    """
    for fold_idx in fold_idcs:
        path_weights_segm = os.path.join(
            config['path_weights'], 'segm', f'fold_{fold_idx}')
        path_ckpt_segm = CheckpointHandler(path_weights_segm).get_last_ckpt()

        # Initialize and configure the model
        model = dict_models[config['model_segm']](
            input_channels=config['input_channels'],
            output_channels=config['output_channels'],
            center_depth=config['center_depth'],
            pretrained=config['pretrained'],
            restore_weights=config['restore_weights'],
            path_weights=path_ckpt_segm)
        model = nn.DataParallel(model).to(maybe_gpu)
        model.eval()

        with tqdm(total=len(loader), desc=f'Eval, fold {fold_idx}') as prog_bar:
            for data_batch in loader:
                # Only the inputs are needed for prediction; the reference
                # masks are not used here, so they are not moved to the device.
                xs = data_batch['xs'].to(maybe_gpu)

                if config['model_segm'] == 'unet_lext':
                    ys_pred = model(xs)
                elif config['model_segm'] == 'unet_lext_aux':
                    # This model returns a second output that is ignored here
                    ys_pred, _ = model(xs)
                else:
                    msg = f"Unknown model {config['model_segm']}"
                    raise ValueError(msg)

                ys_pred_softmax = nn.Softmax(dim=1)(ys_pred)
                data_batch['pred_softmax'] = \
                    ys_pred_softmax.detach().to('cpu').numpy()

                # Rearrange the batch: one dict of batched values becomes
                # a list of per-sample dicts
                data_dicts = [{k: v[n] for k, v in data_batch.items()}
                              for n in range(len(data_batch['image']))]

                for data_dict in data_dicts:
                    dir_base = os.path.join(
                        config['path_predicts'],
                        data_dict['patient'], data_dict['release'],
                        data_dict['sequence'])
                    fname_base = os.path.splitext(
                        os.path.basename(data_dict['path_rel_image']))[0]

                    # Save the predictions
                    dir_predicts = os.path.join(dir_base, 'mask_folds')
                    # exist_ok avoids the check-then-create race
                    os.makedirs(dir_predicts, exist_ok=True)

                    fname_full = os.path.join(
                        dir_predicts,
                        f'{fname_base}_fold_{fold_idx}.tiff')

                    # Quantize the probabilities to 8 bits to keep files small
                    tmp = (data_dict['pred_softmax'] * 255).astype(
                        np.uint8, casting='unsafe')
                    tifffile.imsave(fname_full, tmp, compress=9)

                # One progress step per batch (total is len(loader))
                prog_bar.update(1)
|
|
111 |
|
|
|
112 |
|
|
|
113 |
def merge_predictions(config, source, loader, dict_fns,
                      save_plots=False, remove_foldw=False, convert_to_nifti=True):
    """Merge the fold-wise predictions of every sample into one ensemble mask.

    For each row of ``loader.dataset.df_meta`` the fold-wise TIFF files from
    ``mask_folds`` are averaged and the arg-max label map is written to
    ``mask_foldavg``. The cropped input image goes to ``image_prep`` and, when
    a reference mask is available, its label map goes to ``mask_prep``. The
    metadata table is stored once as ``meta_dynamic.csv`` in
    ``config['path_predicts']``.

    Args:
        config: Mapping providing 'path_predicts' and, when ``save_plots`` is
            set, 'output_channels'.
        source: Mapping providing 'path_root', the directory that the
            relative image/mask paths refer to.
        loader: Object whose ``dataset`` exposes ``df_meta`` (a DataFrame with
            'patient', 'release', 'sequence', 'path_rel_image' and optionally
            'path_rel_mask') and ``read_mask(path)``.
        dict_fns: Mapping whose 'crop' entry is a callable applied to the
            image and the reference mask; its first return item is used.
        save_plots: Also write overlay visualizations to ``vis_foldavg``.
        remove_foldw: Delete the fold-wise prediction files after merging.
        convert_to_nifti: Compose the per-slice PNGs of every group into
            ``image_prep.nii``, ``mask_prep.nii`` and ``mask_foldavg.nii``.
            Requires 'side', 'pixel_spacing_0', 'pixel_spacing_1' and
            'slice_thickness' columns in the metadata.

    Raises:
        ValueError: If no fold-wise prediction file exists for a sample.
    """
    dir_source_root = source['path_root']
    df_meta = loader.dataset.df_meta
    fname_meta = os.path.join(config['path_predicts'], 'meta_dynamic.csv')

    with tqdm(total=len(df_meta), desc='Merge') as prog_bar:
        for _, row in df_meta.iterrows():
            dir_scan_predicts = os.path.join(
                config['path_predicts'],
                row['patient'], row['release'], row['sequence'])
            dir_image_prep = os.path.join(dir_scan_predicts, 'image_prep')
            dir_mask_prep = os.path.join(dir_scan_predicts, 'mask_prep')
            dir_mask_folds = os.path.join(dir_scan_predicts, 'mask_folds')
            dir_mask_foldavg = os.path.join(dir_scan_predicts, 'mask_foldavg')
            dir_vis_foldavg = os.path.join(dir_scan_predicts, 'vis_foldavg')

            for p in (dir_image_prep, dir_mask_prep, dir_mask_folds,
                      dir_mask_foldavg, dir_vis_foldavg):
                os.makedirs(p, exist_ok=True)

            # Find the corresponding prediction files
            fname_base = os.path.splitext(os.path.basename(row['path_rel_image']))[0]

            # glob() returns files in arbitrary order; sort so that the fold
            # order (and the floating-point average) is reproducible
            fnames_pred = sorted(
                glob(os.path.join(dir_mask_folds, f'{fname_base}_fold_*.*')))
            if not fnames_pred:
                # Fail with a clear message instead of letting np.stack
                # raise on an empty list further below
                raise ValueError(
                    f'No fold-wise predictions found for {fname_base} '
                    f'in {dir_mask_folds}')

            # Read the reference data
            image = cv2.imread(
                os.path.join(dir_source_root, row['path_rel_image']),
                cv2.IMREAD_GRAYSCALE)
            # The crop function gets a leading axis added to the 2D image
            image = dict_fns['crop'](image[None, ])[0]
            image = np.squeeze(image)
            if 'path_rel_mask' in row.index:
                ys_true = loader.dataset.read_mask(
                    os.path.join(dir_source_root, row['path_rel_mask']))
                if ys_true is not None:
                    ys_true = dict_fns['crop'](ys_true)[0]
            else:
                ys_true = None

            # Read the fold-wise predictions and undo the 8-bit scaling
            yss_pred = [tifffile.imread(f) for f in fnames_pred]
            ys_pred = np.stack(yss_pred, axis=0).astype(np.float32) / 255
            ys_pred = torch.from_numpy(ys_pred).unsqueeze(dim=0)

            # Average the fold predictions (folds are on axis 1 after the
            # unsqueeze) and renormalize over the axis that follows
            ys_pred = torch.mean(ys_pred, dim=1, keepdim=False)
            ys_pred_softmax = ys_pred / torch.sum(ys_pred, dim=1, keepdim=True)
            ys_pred_softmax_np = ys_pred_softmax.squeeze().numpy()

            ys_pred_arg_np = ys_pred_softmax_np.argmax(axis=0)

            # Save preprocessed input data
            fname_full = os.path.join(dir_image_prep, f'{fname_base}.png')
            cv2.imwrite(fname_full, image)  # image

            if ys_true is not None:
                ys_true_arg_np = \
                    ys_true.astype(np.float32).squeeze().argmax(axis=0)
                fname_full = os.path.join(dir_mask_prep, f'{fname_base}.png')
                cv2.imwrite(fname_full, ys_true_arg_np)  # mask

            if not os.path.exists(fname_meta):
                df_meta.to_csv(fname_meta, index=False)  # metainfo

            # Save ensemble prediction
            fname_full = os.path.join(dir_mask_foldavg, f'{fname_base}.png')
            cv2.imwrite(fname_full, ys_pred_arg_np)

            # Save ensemble visualizations
            if save_plots:
                # Copies are passed to the overlay helper so that it can never
                # alter the masks that the diff plot below compares
                if ys_true is not None:
                    fname_full = os.path.join(
                        dir_vis_foldavg, f"{fname_base}_overlay_mask.png")
                    save_vis_overlay(image=image,
                                     mask=ys_true_arg_np.copy(),
                                     num_classes=config['output_channels'],
                                     fname=fname_full)

                fname_full = os.path.join(
                    dir_vis_foldavg, f"{fname_base}_overlay_pred.png")
                save_vis_overlay(image=image,
                                 mask=ys_pred_arg_np.copy(),
                                 num_classes=config['output_channels'],
                                 fname=fname_full)

                if ys_true is not None:
                    fname_full = os.path.join(
                        dir_vis_foldavg, f"{fname_base}_overlay_diff.png")
                    save_vis_mask_diff(image=image,
                                       mask_true=ys_true_arg_np,
                                       mask_pred=ys_pred_arg_np,
                                       fname=fname_full)

            # Remove the fold predictions (best effort: failures are logged)
            if remove_foldw:
                for f in fnames_pred:
                    try:
                        os.remove(f)
                    except OSError:
                        logger.error(f'Cannot remove {f}')
            prog_bar.update(1)

    # Convert the results to 3D NIfTI images
    if convert_to_nifti:
        df_meta = df_meta.sort_values(by=["patient", "release", "sequence", "side"])

        for gb_name, gb_df in tqdm(
                df_meta.groupby(["patient", "release", "sequence", "side"]),
                desc="Convert to NIfTI"):

            # NOTE(review): groups are keyed by side, but the paths below are
            # not - if one scan directory can hold both sides, the outputs
            # would overlap. Confirm against the dataset layout.
            patient, release, sequence, _ = gb_name
            # Spacings are taken from the first slice of the group
            spacings = (gb_df['pixel_spacing_0'].iloc[0],
                        gb_df['pixel_spacing_1'].iloc[0],
                        gb_df['slice_thickness'].iloc[0])

            dir_scan_predicts = os.path.join(config['path_predicts'],
                                             patient, release, sequence)
            for result in ("image_prep", "mask_prep", "mask_foldavg"):
                pattern = os.path.join(dir_scan_predicts, result, '*.png')
                path_nii = os.path.join(dir_scan_predicts, f"{result}.nii")

                # Read and compose 3D image
                img = png_to_numpy(pattern_fname_in=pattern, reverse=False)

                # Save to NIfTI
                numpy_to_nifti(stack=img, fname_out=path_nii,
                               spacings=spacings, rcp_to_ras=True)
|
|
245 |
|
|
|
246 |
|
|
|
247 |
def save_vis_overlay(image, mask, num_classes, fname):
    """Save `image` with the label map `mask` drawn over it in color.

    Args:
        image: Grayscale image passed to ``label2rgb`` as the background.
        mask: 2D integer label map; label 0 is treated as background.
            The array is not modified.
        num_classes: Number of classes; one sample of each label
            ``0..num_classes-1`` is placed in the first row of the drawn mask.
        fname: Output file name handed to ``cv2.imwrite``.
    """
    # Work on a copy: the class samples written below must not leak into the
    # caller's mask, which may be reused afterwards
    mask = np.array(mask)
    # Add a sample of each class to have consistent class colors
    mask[0, :num_classes] = list(range(num_classes))
    overlay = label2rgb(label=mask, image=image, bg_label=0,
                        colors=['orangered', 'gold', 'lime', 'fuchsia'])
    # Convert to uint8 to save space
    overlay = img_as_ubyte(overlay)
    # Save to file; the channel axis is reversed (RGB -> BGR) for OpenCV
    if overlay.ndim == 3:
        overlay = overlay[:, :, ::-1]
    cv2.imwrite(fname, overlay)
|
|
258 |
|
|
|
259 |
|
|
|
260 |
def save_vis_mask_diff(image, mask_true, mask_pred, fname):
    """Save `image` with the disagreement between two label maps drawn on it.

    Pixels where both masks agree get label 0 and are left uncolored.
    Label 2 marks disagreement where the prediction is 0, and label 3 marks
    disagreement where the prediction is non-zero.

    Args:
        image: Grayscale image passed to ``label2rgb`` as the background.
        mask_true: 2D reference label map.
        mask_pred: 2D predicted label map of the same shape.
        fname: Output file name handed to ``cv2.imwrite``.
    """
    disagree = mask_true != mask_pred
    pred_is_zero = mask_pred == 0

    # Agreement (zero or non-zero label alike) stays 0
    labels = np.zeros_like(mask_true)
    # Prediction is 0 but the reference holds a label
    labels[disagree & pred_is_zero] = 2
    # Prediction holds a label that differs from the reference
    labels[disagree & ~pred_is_zero] = 3

    # Place one sample of every label so the color assignment is stable
    labels[0, :4] = [0, 1, 2, 3]
    palette = ('green', 'red', 'yellow')
    blended = label2rgb(label=labels, image=image, bg_label=0,
                        colors=palette)
    # uint8 keeps the file small
    blended = img_as_ubyte(blended)
    # Reverse the channel axis (RGB -> BGR) before writing with OpenCV
    if blended.ndim == 3:
        blended = blended[:, :, ::-1]
    cv2.imwrite(fname, blended)
|
|
276 |
|
|
|
277 |
|
|
|
278 |
@click.command()
@click.option('--path_data_root', default='../../data')
@click.option('--path_experiment_root', default='../../results/temporary')
@click.option('--model_segm', default='unet_lext')
@click.option('--center_depth', default=1, type=int)
@click.option('--pretrained', is_flag=True)
@click.option('--restore_weights', is_flag=True)
@click.option('--input_channels', default=1, type=int)
@click.option('--output_channels', default=1, type=int)
@click.option('--dataset', type=click.Choice(
    ['oai_imo', 'okoa', 'maknee']))
@click.option('--subset', type=click.Choice(
    ['test', 'all']))
@click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str)
@click.option('--sample_mode', default='x_y', type=str)
@click.option('--batch_size', default=64, type=int)
@click.option('--fold_num', default=5, type=int)
@click.option('--fold_idx', default=-1, type=int)
@click.option('--fold_idx_ignore', multiple=True, type=int)
@click.option('--num_workers', default=1, type=int)
@click.option('--seed_trainval_test', default=0, type=int)
@click.option('--predict_folds', is_flag=True)
@click.option('--merge_predictions', is_flag=True)
@click.option('--save_plots', is_flag=True)
def main(**config):
    """Predict fold-wise segmentations and/or merge them into an ensemble."""
    config['path_data_root'] = os.path.abspath(config['path_data_root'])
    config['path_experiment_root'] = os.path.abspath(config['path_experiment_root'])

    config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights')
    if not os.path.exists(config['path_weights']):
        raise ValueError('{} does not exist'.format(config['path_weights']))

    # NOTE: the "_test" suffix is used for both subsets ('test' and 'all')
    config['path_predicts'] = os.path.join(
        config['path_experiment_root'], f"predicts_{config['dataset']}_test")
    config['path_logs'] = os.path.join(
        config['path_experiment_root'], f"logs_{config['dataset']}_test")

    os.makedirs(config['path_predicts'], exist_ok=True)
    os.makedirs(config['path_logs'], exist_ok=True)

    # Mirror the module logger's output into <path_logs>/main.log
    logging_fh = logging.FileHandler(
        os.path.join(config['path_logs'], 'main.log'))
    logging_fh.setLevel(logging.DEBUG)
    logger.addHandler(logging_fh)

    # Collect the available and specified sources
    sources = sources_from_path(path_data_root=config['path_data_root'],
                                selection=config['dataset'],
                                with_folds=True,
                                seed_trainval_test=config['seed_trainval_test'])

    # Select the subset for evaluation.
    # The module logger is used (not the root logger) so that these messages
    # also reach the log file attached above.
    if config['subset'] == 'test':
        logger.warning('Using the regular trainval-test split')
    elif config['subset'] == 'all':
        logger.warning('Using data selection: full dataset')
        for s in sources:
            sources[s]['test_df'] = sources[s]['sel_df']
            logger.info(f"Selected number of samples: {len(sources[s]['test_df'])}")
    else:
        raise ValueError(f"Unknown subset: {config['subset']}")

    if config['dataset'] == 'oai_imo':
        from rocaseg.datasets import DatasetOAIiMoSagittal2d as DatasetSagittal2d
    elif config['dataset'] == 'okoa':
        from rocaseg.datasets import DatasetOKOASagittal2d as DatasetSagittal2d
    elif config['dataset'] == 'maknee':
        from rocaseg.datasets import DatasetMAKNEESagittal2d as DatasetSagittal2d
    else:
        raise ValueError(f"Unknown dataset: {config['dataset']}")

    # Configure dataset-dependent transforms. The normalization statistics
    # are declared once per dataset and shared by Normalize and UnNormalize.
    fn_crop = CenterCrop(height=300, width=300)
    if config['dataset'] == 'oai_imo':
        norm_stats = {'mean': 0.252699, 'std': 0.251142}
    elif config['dataset'] == 'okoa':
        norm_stats = {'mean': 0.232454, 'std': 0.236259}
    else:
        msg = f"No transforms defined for dataset: {config['dataset']}"
        raise NotImplementedError(msg)
    fn_norm = Normalize(**norm_stats)
    fn_unnorm = UnNormalize(**norm_stats)
    dict_fns = {'crop': fn_crop, 'norm': fn_norm, 'unnorm': fn_unnorm}

    dataset_test = DatasetSagittal2d(
        df_meta=sources[config['dataset']]['test_df'], mask_mode=config['mask_mode'],
        name=config['dataset'], sample_mode=config['sample_mode'],
        transforms=[
            PercentileClippingAndToFloat(cut_min=10, cut_max=99),
            fn_crop,
            fn_norm,
            ToTensor()
        ])
    loader_test = DataLoader(dataset_test,
                             batch_size=config['batch_size'],
                             shuffle=False,
                             num_workers=config['num_workers'],
                             drop_last=False)

    # Build a list of folds to run on: -1 selects all folds
    if config['fold_idx'] == -1:
        fold_idcs = list(range(config['fold_num']))
    else:
        fold_idcs = [config['fold_idx'], ]
    ignored = set(config['fold_idx_ignore'])
    fold_idcs = [i for i in fold_idcs if i not in ignored]

    # Execute
    with torch.no_grad():
        if config['predict_folds']:
            predict_folds(config=config, loader=loader_test, fold_idcs=fold_idcs)

        if config['merge_predictions']:
            merge_predictions(config=config, source=sources[config['dataset']],
                              loader=loader_test, dict_fns=dict_fns,
                              save_plots=config['save_plots'], remove_foldw=False,
                              convert_to_nifti=True)
|
|
395 |
|
|
|
396 |
|
|
|
397 |
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()