|
a |
|
b/src/run_post_processing.py |
|
|
1 |
import os |
|
|
2 |
from typing import Tuple |
|
|
3 |
|
|
|
4 |
from src.compute_metric_results import compute_wt_tc_et |
|
|
5 |
from src.dataset import brats_labels |
|
|
6 |
from src.dataset.utils.nifi_volume import load_nifi_volume_return_nib, save_segmask_as_nifi_volume |
|
|
7 |
from src.post_processing import post_process |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
def load_volume(path) -> Tuple: |
|
|
11 |
return load_nifi_volume_return_nib(path, normalize=False) |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
def compute_metrics(ground_truth_path, subject, clean_segmentation): |
|
|
15 |
gt_path = os.path.join(ground_truth_path, subject, f"{subject}_seg.nii.gz") |
|
|
16 |
data_path = os.path.join(ground_truth_path, subject, f"{subject}_flair.nii.gz") |
|
|
17 |
|
|
|
18 |
volume_gt, _ = load_volume(gt_path) |
|
|
19 |
volume, _ = load_volume(data_path) |
|
|
20 |
|
|
|
21 |
metrics_after = compute_wt_tc_et(clean_segmentation, volume_gt, volume) |
|
|
22 |
|
|
|
23 |
print(f"{subject},After {metrics_after}") |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
if __name__ == "__main__": |
|
|
27 |
setx = "train" |
|
|
28 |
th = 1 |
|
|
29 |
model_id = "model_1598640035" |
|
|
30 |
model_path = f"/mnt/gpid07/users/laura.mora/results/checkpoints/{model_id}/" |
|
|
31 |
|
|
|
32 |
input_dir = os.path.join(model_path, f"segmentation_task", setx) |
|
|
33 |
output_dir = os.path.join(model_path, "segmentation_task_clean_keep_one_two_wt", setx) |
|
|
34 |
if not os.path.exists(output_dir): |
|
|
35 |
os.makedirs(output_dir) |
|
|
36 |
|
|
|
37 |
file_list = sorted([file for file in os.listdir(input_dir) if "BraTS20" in file]) |
|
|
38 |
idx = int(os.environ.get("SLURM_ARRAY_TASK_ID")) if os.environ.get("SLURM_ARRAY_TASK_ID") else 6 |
|
|
39 |
|
|
|
40 |
filename = file_list[idx] |
|
|
41 |
|
|
|
42 |
subject = filename.split(".")[0] |
|
|
43 |
output_path = os.path.join(output_dir, f"{subject}.nii.gz") |
|
|
44 |
prediction_path = os.path.join(input_dir, f"{subject}.nii.gz") |
|
|
45 |
segmentation, segmentation_nib = load_volume(prediction_path) |
|
|
46 |
segmentation_post = segmentation.copy() |
|
|
47 |
|
|
|
48 |
print("Post processing") |
|
|
49 |
|
|
|
50 |
# Keep ONE OR TWO WT |
|
|
51 |
pred_mask_wt = brats_labels.get_wt(segmentation_post) |
|
|
52 |
mask_removed_regions_wt = post_process.keep_conn_component_bigger_than_th(pred_mask_wt, th=th) |
|
|
53 |
elements_to_remove = pred_mask_wt - mask_removed_regions_wt |
|
|
54 |
segmentation_post[elements_to_remove == 1] = 0 |
|
|
55 |
|
|
|
56 |
if setx == "train": |
|
|
57 |
print("Computing metrics..") |
|
|
58 |
ground_truth_path = f"/mnt/gpid07/users/laura.mora/datasets/2020/{setx}/no_patch" |
|
|
59 |
compute_metrics(ground_truth_path, subject, segmentation_post) |
|
|
60 |
|
|
|
61 |
affine_func = segmentation_nib.affine |
|
|
62 |
save_segmask_as_nifi_volume(segmentation_post, affine_func, output_path) |
|
|
63 |
print("Result Saved!") |