Diff of /Region/region_crop.py [000000] .. [1cac92]

Switch to unified view

a b/Region/region_crop.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
"""
4
Created on Thu Nov  1 17:14:27 2018
5
6
@author: Josefine
7
"""
8
9
import numpy as np
10
import re
11
import nibabel as nib
12
import glob
13
from skimage.transform import resize
14
15
def natural_sort(l): 
16
    convert = lambda text: int(text) if text.isdigit() else text.lower() 
17
    alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 
18
    return sorted(l, key = alphanum_key)
19
20
# Create original high res data function:
21
def create_data(filename_img):
22
    a = nib.load(filename_img)
23
    a = a.get_data()
24
    a2 = np.clip(a,-1000,1000)
25
    a3 = np.interp(a2, (a2.min(), a2.max()), (-1, +1))
26
    # Reshape:
27
    img = np.zeros([512,512,512])+np.min(a3)
28
    index1 = int(np.ceil((512-a.shape[2])/2))
29
    index2 = int(512-np.floor((512-a.shape[2])/2))
30
    img[:,:,index1:index2] = a3
31
    images = img.transpose((2,0,1))
32
    return images
33
34
def create_label(filename_label):
35
    # Label creation
36
    b = nib.load(filename_label)
37
    b = b.get_data()
38
    img = np.zeros([b.shape[0],b.shape[0],b.shape[0]])
39
    index1 = int(np.ceil((img.shape[2]-b.shape[2])/2))
40
    index2 = int(img.shape[2]-np.floor((img.shape[2]-b.shape[2])/2))
41
    img[:,:,index1:index2] = b
42
    labels = img.transpose((2,0,1))
43
    return labels
44
45
# Fusion of low resolution probablity maps
46
def fusion(prob_maps_axial, prob_maps_sag, prob_maps_cor):
47
    # Reshape sagittal data to match axial:
48
    sag_to_axial = prob_maps_sag.transpose((2, 0, 1, 3))
49
    # Reshape coronal data to match axial:
50
    cor_to_sag = prob_maps_cor.transpose((1, 0, 2, 3))
51
    cor_to_axial = cor_to_sag.transpose((2, 0, 1, 3))
52
    temp = np.maximum.reduce([sag_to_axial,cor_to_axial,prob_maps_axial])
53
    return temp
54
55
# Region retraction
56
def cut_region(volumen1):
57
    for i in range(volumen1.shape[0]):
58
        if np.max(volumen1[i,:,:]) == 1:
59
            break    
60
    
61
    for j in range(volumen1.shape[1]):
62
        if np.max(volumen1[:,j,:]) == 1:
63
            break    
64
        
65
    for k in range(volumen1.shape[2]):
66
        if np.max(volumen1[:,:,k]) == 1:
67
            break
68
        
69
    for i2 in reversed(range(volumen1.shape[0])):
70
        if np.max(volumen1[i2,:,:]) == 1:
71
            break    
72
    
73
    for j2 in reversed(range(volumen1.shape[1])):
74
        if np.max(volumen1[:,j2,:]) == 1:
75
            break    
76
        
77
    for k2 in reversed(range(volumen1.shape[2])):
78
        if np.max(volumen1[:,:,k2]) == 1:
79
            break    
80
    #factor = int(np.ceil(0.02*volumen1.shape[0]))
81
    #cut_volumen = volumen1[i-factor:i2+factor,j-factor:j2+factor,k-factor:k2+factor]
82
    return i,i2,j,j2,k,k2
83
84
# Load data:
85
filelist_test = natural_sort(glob.glob('WHS/ct_train_test/ct_test/*_image.nii.gz')) # list of file names
86
87
filelist_train = natural_sort(glob.glob('WHS/Augment_data/*_image.nii')) # list of file names
88
filelist_train_label = natural_sort(glob.glob('WHS/Augment_data/*_label.nii')) # list of file names
89
90
# Load test data:
91
files_p0_axial = natural_sort(glob.glob('WHS/Results/Predictions/region/test_prob_maps_axial_*.npz')) # list of file names
92
files_p0_sag = natural_sort(glob.glob('WHS/Results/Predictions/region/test_prob_maps_sag_*.npz')) # list of file names
93
files_p0_cor = natural_sort(glob.glob('WHS/Results/Predictions/region/test_prob_maps_cor_*.npz')) # list of file names
94
95
## Load train data:
96
files_p1_axial = natural_sort(glob.glob('WHS/Results/Predictions/region/train_prob_maps_axial_*.npz')) # list of file names
97
files_p1_sag = natural_sort(glob.glob('WHS/Results/Predictions/region/train_prob_maps_sag_*.npz')) # list of file names
98
files_p1_cor = natural_sort(glob.glob('WHS/Results/Predictions/region/train_prob_maps_cor_*.npz')) # list of file names
99
100
#for n in range(len(files_p0_axial)):
101
#    axial_data = np.load(files_p0_axial[n])
102
#    prob_maps_axial = axial_data['prob_maps']
103
#    sag_data = np.load(files_p0_sag[n])
104
#    prob_maps_sag = sag_data['prob_maps']
105
#    cor_data = np.load(files_p0_cor[n])
106
#    prob_maps_cor = cor_data['prob_maps']
107
#
108
#    # Create fused propability map
109
#    fused_prob_maps = fusion(prob_maps_axial, prob_maps_sag, prob_maps_cor)
110
#    full_prob_maps = np.zeros([512,512,512,2])
111
#    for i in range(2):
112
#        full_prob_maps[:,:,:,i] = resize(fused_prob_maps[:,:,:,i],(512,512,512))    
113
#    label = full_prob_maps.argmax(axis=-1)
114
#    image = create_data(filelist_test[n])
115
#
116
#    # Get bounding box
117
#    i,i2,j,j2,k,k2 = cut_region(label)
118
#    # Load original data
119
#    factor =int(np.ceil(0.02*image.shape[0]))
120
#    start = int(np.floor(np.min([i,j,k])-factor))
121
#    end = int(np.ceil(np.max([i2,j2,k2])+factor))
122
#    cut = [start,end]
123
#    if cut[0] < 0:
124
#        cut[0] = 0
125
#    if cut[1] > image.shape[0]:
126
#        cut[1] = image.shape[0]
127
#    # Crop bounding box of original data
128
#    cut_img = image[cut[0]:cut[1],cut[0]:cut[1],cut[0]:cut[1]]
129
#    np.savez('WHS/Data/test_segments_{}'.format(n),images=cut_img,cut=cut)
130
#    print('Test image', (n+1), 'cut', (cut))
131
132
for n in range(len(files_p1_axial)):
133
    axial_data = np.load(files_p1_axial[n])
134
    prob_maps_axial = axial_data['prob_maps']
135
    sag_data = np.load(files_p1_sag[n])
136
    prob_maps_sag = sag_data['prob_maps']
137
    cor_data = np.load(files_p1_cor[n])
138
    prob_maps_cor = cor_data['prob_maps']
139
140
    # Create fused propability map
141
    fused_prob_maps = fusion(prob_maps_axial, prob_maps_sag, prob_maps_cor)
142
    labels = fused_prob_maps.argmax(axis=-1)
143
    image = create_data(filelist_train[n])
144
    groundtruth = create_label(filelist_train_label[n])
145
    # Get bounding box
146
    i,i2,j,j2,k,k2 = cut_region(labels)
147
148
    # Load original data
149
    factor =int(np.ceil(0.02*groundtruth.shape[0]))
150
    mult_factor = image.shape[0]/labels.shape[0]
151
    start = int(np.floor(np.min([i,j,k])*mult_factor-factor))
152
    end = int(np.ceil(np.max([i2,j2,k2])*mult_factor+factor))
153
    cut = [start,end]
154
    if cut[0] < 0:
155
        cut[0] = 0
156
    if cut[1] > image.shape[0]:
157
        cut[1] = image.shape[0]
158
    # Crop bounding box of original data
159
    cut_GT = groundtruth[cut[0]:cut[1],cut[0]:cut[1],cut[0]:cut[1]]
160
    cut_GT = np.round(cut_GT)
161
    cut_img = image[cut[0]:cut[1],cut[0]:cut[1],cut[0]:cut[1]]
162
    np.savez('WHS/Data/train_segments_{}'.format(n),images=cut_img,labels=cut_GT,cut=cut)
163
    print('Train image', (n+1), 'cut', (cut))