[96354c]: / src / test / predict.py

Download this file

61 lines (41 with data), 2.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
from typing import Tuple
import torch
import numpy as np
from src.dataset.utils.nifi_volume import save_segmask_as_nifi_volume
from src.dataset.patient import Patient
from src.logging_conf import logger
def enable_dropout(model):
    """Put every dropout layer of ``model`` back into training mode.

    A submodule counts as a dropout layer when its class name starts with
    ``'Dropout'`` (e.g. ``Dropout``, ``Dropout2d``, ``Dropout3d``). Names such as
    ``AlphaDropout`` do not match this check and are left alone, as are all
    other modules. Calling this right after ``model.eval()`` keeps dropout
    stochastic at inference time (Monte Carlo dropout).

    :param model: module whose ``modules()`` (itself and all descendants) are scanned.
    """
    dropout_layers = (
        module
        for module in model.modules()
        if type(module).__name__.startswith('Dropout')
    )
    for layer in dropout_layers:
        layer.train()
def predict(model, images: np.ndarray, device: torch.device, monte_carlo: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run one forward pass of ``model`` on a single, un-batched sample.

    The model is switched to eval mode. When ``monte_carlo`` is True its dropout
    layers are switched back to train mode via ``enable_dropout``, so repeated
    calls give stochastic samples. No gradients are tracked.

    :param model: module whose forward returns a pair
        ``(four_channel_output, prediction_scores)``.
    :param images: one sample as a numpy array; a batch axis of size 1 is
        prepended and the data is cast to float32 before the forward pass.
        Non-contiguous arrays (e.g. produced by ``np.flip``) are accepted.
    :param device: device the input is moved to before calling the model.
    :param monte_carlo: keep dropout active during inference.
    :return: both model outputs as torch tensors moved to the CPU.
    """
    model.eval()
    if monte_carlo:
        enable_dropout(model)
    with torch.no_grad():
        # torch.from_numpy rejects arrays with negative strides (e.g. after np.flip);
        # ascontiguousarray returns the same array when it is already contiguous.
        batch = torch.from_numpy(np.ascontiguousarray(images)).unsqueeze(0)
        inputs = batch.float().to(device)
        four_channel_output, prediction_scores = model(inputs)
    return four_channel_output.detach().cpu(), prediction_scores.detach().cpu()
def get_prediction_map(four_channel_prediction: torch.tensor) -> np.ndarray:
    """Collapse per-class network output into a label map for the first sample.

    :param four_channel_prediction: 5-D tensor. Element 0 of the first axis is
        taken and the arg-max is computed over the following axis, presumably
        (batch, classes, D, H, W); TODO confirm against callers.
    :return: uint8 numpy array with, for each voxel, the index of the
        highest-scoring class.
    """
    assert len(four_channel_prediction.shape) == 5
    first_sample = four_channel_prediction[0]
    winning_class = first_sample.max(dim=0).indices
    return np.asarray(winning_class.byte().cpu().data)
def get_scores_map_from_vector(vector_prediction_scores: torch.Tensor, path_size: list) -> torch.Tensor:
    """Reduce per-voxel class scores to the best score and reshape to a volume.

    :param vector_prediction_scores: 2-D scores of shape (all_voxels, n_labels).
        A torch tensor is expected. A numpy array is converted with
        ``torch.as_tensor`` first, because the reduction relies on torch's
        ``max(dim)`` returning (values, indices) and on torch reshaping.
    :param path_size: target shape (list or tuple of ints); its product must
        equal ``all_voxels``.
    :return: torch tensor of shape ``path_size`` holding each voxel's maximum score.
    :raises AssertionError: if the scores are not 2-D.
    """
    scores = torch.as_tensor(vector_prediction_scores)
    assert len(scores.shape) == 2, "Must be a 2d array with: (all_voxels, n_labels)"
    best_score = scores.max(dim=1).values
    return best_score.reshape(path_size)
def save_predictions(patient: Patient, results: dict, model_path: str, task: str):
    """Write every volume in ``results`` as a NIfTI file for ``patient``.

    Files are written under ``<model_path>/<task>/``. Entries whose key contains
    ``"entropy"`` go to the ``entropy`` sub-directory instead. The
    ``"prediction"`` entry is saved as ``<patch_name>.nii.gz``; every other
    entry is saved as ``<patch_name>_unc_<name>.nii.gz``. Both directories are
    created even when ``results`` is empty.

    :param patient: supplies ``patch_name`` and the affine (``get_affine()``)
        written into every file.
    :param results: mapping from result name to the volume to save.
    :param model_path: root directory for model outputs.
    :param task: sub-directory name under ``model_path``.
    """
    output_dir = os.path.join(model_path, task)
    output_dir_entropy = os.path.join(output_dir, "entropy")
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(output_dir_entropy, exist_ok=True)

    if not results:
        return
    # get_affine() takes no per-volume input, so fetch it once instead of per file.
    affine_func = patient.get_affine()
    for name, volume in results.items():
        if name == "prediction":
            file_name = f"{patient.patch_name}.nii.gz"
        else:
            file_name = f"{patient.patch_name}_unc_{name}.nii.gz"
        # Route on the result key, not on the file name: a patch_name that happens
        # to contain "entropy" must not send the segmentation to the entropy folder.
        directory = output_dir_entropy if "entropy" in str(name) else output_dir
        output_path = os.path.join(directory, file_name)
        logger.info(f"Saving to: {output_path}")
        save_segmask_as_nifi_volume(volume, affine_func, output_path)