Diff of /Segmentation/fusion.py [000000] .. [1cac92]

Switch to unified view

a b/Segmentation/fusion.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
"""
4
Created on Tue Nov 20 19:44:39 2018
5
6
@author: Josefine
7
"""
8
9
import numpy as np
10
import re
11
import nibabel as nib
12
import glob
13
from skimage.transform import resize
14
from scipy import ndimage
15
16
def natural_sort(l): 
17
    convert = lambda text: int(text) if text.isdigit() else text.lower() 
18
    alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 
19
    return sorted(l, key = alphanum_key)
20
21
# Create original high res data function:
22
def create_data(filename_img,filename_label):
23
    images = []
24
    a = nib.load(filename_img)
25
    a = a.get_data()
26
    # Normalize:
27
    a2 = np.clip(a,-1000,1000)
28
    a3 = np.interp(a2, (a2.min(), a2.max()), (-1, +1))
29
    # Reshape:
30
    img = np.zeros([512,512,512])+np.min(a3)
31
    index1 = int(np.ceil((512-a.shape[2])/2))
32
    index2 = int(512-np.floor((512-a.shape[2])/2))
33
    img[:,:,index1:index2] = a3
34
    for i in range(img.shape[2]):
35
            images.append((img[:,:,i]))
36
    images = np.asarray(images)
37
38
    return images
39
40
# Fusion of low resolution probablity maps
41
def fusion(prob_maps_axial, prob_maps_sag, prob_maps_cor):
42
    sag_to_axial = []
43
    for i in range(prob_maps_sag.shape[2]):
44
        sag_to_axial.append((prob_maps_sag[:,i,:,:]))  
45
    sag_to_axial = np.asarray(sag_to_axial)
46
    
47
    # Reshape coronal data to match axial:
48
    cor_to_sag = []
49
    for i in range(prob_maps_cor.shape[2]):
50
        cor_to_sag.append((prob_maps_cor[:,i,:,:]))  
51
    cor_to_sag = np.asarray(cor_to_sag)
52
    cor_to_axial = []
53
    for i in range(prob_maps_cor.shape[2]):
54
        cor_to_axial.append((cor_to_sag[:,:,i,:]))  
55
    cor_to_axial = np.asarray(cor_to_axial)
56
    cor_to_axial2 = []
57
    for i in range(prob_maps_cor.shape[2]):
58
        cor_to_axial2.append((cor_to_axial[:,i,:,:]))  
59
    cor_to_axial = np.asarray(cor_to_axial2)
60
    
61
    temp = np.maximum.reduce([sag_to_axial,cor_to_axial,prob_maps_axial])
62
    return temp
63
64
def remove_objects(binary_mask):
65
    labelled_mask, num_labels = ndimage.label(binary_mask)
66
67
    # Let us now remove all the too small regions.
68
    refined_mask = binary_mask.copy()
69
    minimum_cc_sum = 5000
70
    for label in range(num_labels):
71
        if np.sum(refined_mask[labelled_mask == label]) < minimum_cc_sum:
72
            refined_mask[labelled_mask == label] = 0
73
    return refined_mask
74
75
filelist_train = natural_sort(glob.glob('WHS/ct_train_test/ct_test/*_image.nii.gz')) # list of file names
76
cropped_files = natural_sort(glob.glob('WHS/Data/test_segments_*.npz')) # list of file names
77
78
files_axial = natural_sort(glob.glob('WHS/Results/Predictions/segment/train_prob_maps_axial_*.npz')) # list of file names
79
files_sag = natural_sort(glob.glob('WHS/Results/Predictions/segment/train_prob_maps_sag_*.npz')) # list of file names
80
files_cor = natural_sort(glob.glob('WHS/Results/Predictions/segment/train_prob_maps_cor_*.npz')) # list of file names
81
82
for n in range(len(files_axial)):
83
    axial_data = np.load(files_axial[n])
84
    prob_maps_axial = axial_data['prob_maps']
85
    sag_data = np.load(files_sag[n])
86
    prob_maps_sag = sag_data['prob_maps']
87
    cor_data = np.load(files_cor[n])
88
    prob_maps_cor = cor_data['prob_maps']
89
    cut_file = np.load(cropped_files[n])
90
    cut = cut_file['cut']
91
92
    # Create fused propability map
93
    fused_prob_maps = fusion(prob_maps_axial, prob_maps_sag, prob_maps_cor)
94
    side_length = cut[1]-cut[0]
95
    lab = np.zeros([side_length,side_length,side_length,8])
96
    for i in range(8):
97
        lab[:,:,:,i] = resize(fused_prob_maps[:,:,:,i],(side_length,side_length,side_length))
98
    full_labels = np.zeros([512,512,512,8])
99
    full_labels[cut[0]:cut[1],cut[0]:cut[1],cut[0]:cut[1],:] = lab
100
    labels = full_labels.argmax(axis=-1)
101
    print('Test image', (n+1))
102
    np.savez('WHS/Results/Predictions/final/prediction_{}'.format(n),prob_map = labels)