Switch to unified view

a b/src/run_post_processing.py
1
import os
2
from typing import Tuple
3
4
from src.compute_metric_results import compute_wt_tc_et
5
from src.dataset import brats_labels
6
from src.dataset.utils.nifi_volume import load_nifi_volume_return_nib, save_segmask_as_nifi_volume
7
from src.post_processing import post_process
8
9
10
def load_volume(path) -> Tuple:
11
    return load_nifi_volume_return_nib(path, normalize=False)
12
13
14
def compute_metrics(ground_truth_path, subject, clean_segmentation):
15
    gt_path = os.path.join(ground_truth_path, subject, f"{subject}_seg.nii.gz")
16
    data_path = os.path.join(ground_truth_path, subject, f"{subject}_flair.nii.gz")
17
18
    volume_gt, _ = load_volume(gt_path)
19
    volume, _ = load_volume(data_path)
20
21
    metrics_after = compute_wt_tc_et(clean_segmentation, volume_gt, volume)
22
23
    print(f"{subject},After  {metrics_after}")
24
25
26
if __name__ == "__main__":
27
    setx = "train"
28
    th = 1
29
    model_id = "model_1598640035"
30
    model_path = f"/mnt/gpid07/users/laura.mora/results/checkpoints/{model_id}/"
31
32
    input_dir = os.path.join(model_path, f"segmentation_task", setx)
33
    output_dir = os.path.join(model_path, "segmentation_task_clean_keep_one_two_wt", setx)
34
    if not os.path.exists(output_dir):
35
        os.makedirs(output_dir)
36
37
    file_list = sorted([file for file in os.listdir(input_dir) if "BraTS20" in file])
38
    idx = int(os.environ.get("SLURM_ARRAY_TASK_ID")) if os.environ.get("SLURM_ARRAY_TASK_ID") else 6
39
40
    filename = file_list[idx]
41
42
    subject = filename.split(".")[0]
43
    output_path = os.path.join(output_dir, f"{subject}.nii.gz")
44
    prediction_path = os.path.join(input_dir, f"{subject}.nii.gz")
45
    segmentation, segmentation_nib = load_volume(prediction_path)
46
    segmentation_post = segmentation.copy()
47
48
    print("Post processing")
49
50
    # Keep ONE OR TWO WT
51
    pred_mask_wt = brats_labels.get_wt(segmentation_post)
52
    mask_removed_regions_wt = post_process.keep_conn_component_bigger_than_th(pred_mask_wt, th=th)
53
    elements_to_remove = pred_mask_wt - mask_removed_regions_wt
54
    segmentation_post[elements_to_remove == 1] = 0
55
56
    if setx == "train":
57
        print("Computing metrics..")
58
        ground_truth_path = f"/mnt/gpid07/users/laura.mora/datasets/2020/{setx}/no_patch"
59
        compute_metrics(ground_truth_path, subject, segmentation_post)
60
61
    affine_func = segmentation_nib.affine
62
    save_segmask_as_nifi_volume(segmentation_post, affine_func, output_path)
63
    print("Result Saved!")