[1cac92]: / Region / region_crop.py

Download this file

164 lines (143 with data), 6.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 1 17:14:27 2018
@author: Josefine
"""
import numpy as np
import re
import nibabel as nib
import glob
from skimage.transform import resize
def natural_sort(l):
    """Return a new list with the strings of *l* in natural (human) order.

    Runs of ASCII digits compare numerically and all other text compares
    case-insensitively, so ``'img2'`` sorts before ``'img10'``.
    """
    def _natural_key(name):
        # Splitting on a capturing group keeps the digit runs in the result.
        pieces = re.split('([0-9]+)', name)
        return [int(piece) if piece.isdigit() else piece.lower()
                for piece in pieces]

    return sorted(l, key=_natural_key)
# Create original high res data function:
def create_data(filename_img, size=512):
    """Load a CT volume, rescale it to [-1, 1] and pad its third axis to a cube.

    Parameters
    ----------
    filename_img : str
        Path to an image file readable by ``nibabel.load``.
    size : int, optional
        Edge length of the output cube (default 512). The first two axes of
        the loaded volume must already equal ``size`` (assumption -- the
        assignment below fails otherwise).

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(size, size, size)``. The padded (third)
        axis of the input is moved to the front, i.e. the padded cube is
        returned as ``cube.transpose((2, 0, 1))``.

    Raises
    ------
    ValueError
        If the third axis of the volume is longer than ``size``.
    """
    # get_data() was deprecated in nibabel 3.0 and removed in 5.0;
    # get_fdata() returns the same (scaled) voxel values as float64.
    a = nib.load(filename_img).get_fdata()
    # Clip intensities to [-1000, 1000] (presumably Hounsfield units for
    # these CT scans), then map the clipped min..max linearly onto [-1, 1].
    a2 = np.clip(a, -1000, 1000)
    a3 = np.interp(a2, (a2.min(), a2.max()), (-1, +1))
    depth = a.shape[2]
    if depth > size:
        raise ValueError(
            'create_data: image depth {} exceeds target size {}'.format(depth, size))
    # Fill the cube with the background value (minimum of the rescaled data)
    # and centre the volume along the third axis; with odd padding the extra
    # slice goes before the data (ceil on the leading side).
    img = np.full((size, size, size), np.min(a3))
    pad = size - depth
    index1 = int(np.ceil(pad / 2))
    index2 = int(size - np.floor(pad / 2))
    img[:, :, index1:index2] = a3
    return img.transpose((2, 0, 1))
def create_label(filename_label):
    """Load a label volume and zero-pad its third axis into a cube.

    The cube edge is the length of the label volume's first axis; the
    second axis is assumed to have that same length (assumption -- the
    assignment below fails otherwise).

    Parameters
    ----------
    filename_label : str
        Path to a label image readable by ``nibabel.load``.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(n, n, n)`` with ``n = shape[0]`` of the
        label volume; the padded (third) axis is moved to the front.

    Raises
    ------
    ValueError
        If the third axis is longer than the first.
    """
    # get_data() was removed in nibabel 5.0; get_fdata() returns the same
    # values as float64 (the padded output array is float64 regardless).
    b = nib.load(filename_label).get_fdata()
    edge = b.shape[0]
    depth = b.shape[2]
    if depth > edge:
        raise ValueError(
            'create_label: label depth {} exceeds cube edge {}'.format(depth, edge))
    img = np.zeros((edge, edge, edge))
    # Centre along the third axis; odd padding puts the extra zero slice first.
    index1 = int(np.ceil((edge - depth) / 2))
    index2 = int(edge - np.floor((edge - depth) / 2))
    img[:, :, index1:index2] = b
    return img.transpose((2, 0, 1))
# Fusion of low-resolution probability maps
def fusion(prob_maps_axial, prob_maps_sag, prob_maps_cor):
    """Fuse three views' probability maps by taking the voxel-wise maximum.

    The sagittal and coronal maps are permuted into the axial map's axis
    order; the trailing axis is left in place in every map (presumably the
    class axis -- the caller takes ``argmax`` over it). All three arrays
    must have identical shapes after permutation.

    Returns
    -------
    numpy.ndarray
        Element-wise maximum of the three aligned maps, shaped like
        ``prob_maps_axial``.
    """
    # Sagittal -> axial: the third spatial axis moves to the front.
    sagittal = prob_maps_sag.transpose(2, 0, 1, 3)
    # Coronal -> axial: swapping axes 0/1 and then applying (2, 0, 1) is the
    # same single permutation as reversing the three spatial axes.
    coronal = prob_maps_cor.transpose(2, 1, 0, 3)
    aligned = [sagittal, coronal, prob_maps_axial]
    return np.maximum.reduce(aligned)
# Region extraction: bounding box of the voxels labelled 1
def cut_region(volumen1):
    """Return the bounding box of the region labelled 1 in a volume.

    Along each of the first three axes, a slice counts as occupied when its
    maximum value equals 1. For a binary 0/1 mask this means "contains a 1".

    Parameters
    ----------
    volumen1 : array_like
        Array with at least three dimensions. Only the first three axes are
        scanned; any further axes are included in each slice's maximum.

    Returns
    -------
    tuple of int
        ``(i, i2, j, j2, k, k2)``: first and last occupied index (both
        inclusive) along axes 0, 1 and 2 respectively.

    Raises
    ------
    ValueError
        If the array has fewer than three dimensions, or if some axis has no
        slice whose maximum is 1. Raising here replaces silently returning
        an inverted box.
    """
    volumen1 = np.asarray(volumen1)
    if volumen1.ndim < 3:
        raise ValueError(
            'cut_region expects an array with at least 3 dimensions, '
            'got shape {}'.format(volumen1.shape))
    bounds = []
    for axis in range(3):
        # Reducing over every other axis yields one maximum per slice along
        # `axis`. This is the vectorised form of np.max(volumen1[idx, :, :]).
        other_axes = tuple(ax for ax in range(volumen1.ndim) if ax != axis)
        occupied = np.flatnonzero(volumen1.max(axis=other_axes) == 1)
        if occupied.size == 0:
            raise ValueError(
                'cut_region found no slice with maximum 1 along axis {}'.format(axis))
        bounds.extend((int(occupied[0]), int(occupied[-1])))
    return tuple(bounds)
# Load data:
# Collect input file names (glob only -- nothing is read yet). Each list is
# naturally sorted ('..._2_...' before '..._10_...') because the processing
# loop below pairs entries from different lists purely by index.
filelist_test, filelist_train, filelist_train_label = (
    natural_sort(glob.glob(pattern))
    for pattern in (
        'WHS/ct_train_test/ct_test/*_image.nii.gz',  # test CT images
        'WHS/Augment_data/*_image.nii',  # training CT images
        'WHS/Augment_data/*_label.nii',  # training label volumes
    )
)
# Predicted region probability maps for the test set, one archive per view.
files_p0_axial, files_p0_sag, files_p0_cor = (
    natural_sort(glob.glob(pattern))
    for pattern in (
        'WHS/Results/Predictions/region/test_prob_maps_axial_*.npz',
        'WHS/Results/Predictions/region/test_prob_maps_sag_*.npz',
        'WHS/Results/Predictions/region/test_prob_maps_cor_*.npz',
    )
)
# Predicted region probability maps for the training set, one archive per view.
files_p1_axial, files_p1_sag, files_p1_cor = (
    natural_sort(glob.glob(pattern))
    for pattern in (
        'WHS/Results/Predictions/region/train_prob_maps_axial_*.npz',
        'WHS/Results/Predictions/region/train_prob_maps_sag_*.npz',
        'WHS/Results/Predictions/region/train_prob_maps_cor_*.npz',
    )
)
#for n in range(len(files_p0_axial)):
# axial_data = np.load(files_p0_axial[n])
# prob_maps_axial = axial_data['prob_maps']
# sag_data = np.load(files_p0_sag[n])
# prob_maps_sag = sag_data['prob_maps']
# cor_data = np.load(files_p0_cor[n])
# prob_maps_cor = cor_data['prob_maps']
#
# # Create fused propability map
# fused_prob_maps = fusion(prob_maps_axial, prob_maps_sag, prob_maps_cor)
# full_prob_maps = np.zeros([512,512,512,2])
# for i in range(2):
# full_prob_maps[:,:,:,i] = resize(fused_prob_maps[:,:,:,i],(512,512,512))
# label = full_prob_maps.argmax(axis=-1)
# image = create_data(filelist_test[n])
#
# # Get bounding box
# i,i2,j,j2,k,k2 = cut_region(label)
# # Load original data
# factor =int(np.ceil(0.02*image.shape[0]))
# start = int(np.floor(np.min([i,j,k])-factor))
# end = int(np.ceil(np.max([i2,j2,k2])+factor))
# cut = [start,end]
# if cut[0] < 0:
# cut[0] = 0
# if cut[1] > image.shape[0]:
# cut[1] = image.shape[0]
# # Crop bounding box of original data
# cut_img = image[cut[0]:cut[1],cut[0]:cut[1],cut[0]:cut[1]]
# np.savez('WHS/Data/test_segments_{}'.format(n),images=cut_img,cut=cut)
# print('Test image', (n+1), 'cut', (cut))
def _load_prob_maps(path):
    """Read the ``prob_maps`` array from the ``.npz`` archive at *path*.

    The archive is opened in a ``with`` block so its file handle is closed
    as soon as the array has been read into memory.
    """
    with np.load(path) as archive:
        return archive['prob_maps']


# For every training volume: fuse the three views' predictions, find the
# predicted region, and save the matching crop of the image and its label.
for n, path_axial in enumerate(files_p1_axial):
    prob_maps_axial = _load_prob_maps(path_axial)
    prob_maps_sag = _load_prob_maps(files_p1_sag[n])
    prob_maps_cor = _load_prob_maps(files_p1_cor[n])
    # Create the fused probability map and pick the winning class per voxel.
    fused_prob_maps = fusion(prob_maps_axial, prob_maps_sag, prob_maps_cor)
    labels = fused_prob_maps.argmax(axis=-1)
    image = create_data(filelist_train[n])
    groundtruth = create_label(filelist_train_label[n])
    # Bounding box of label-1 voxels on the (lower-resolution) label grid.
    i, i2, j, j2, k, k2 = cut_region(labels)
    # Margin of 2% of the volume edge. Low-resolution indices are scaled up
    # to the image grid. A single [start, end) range is used on all three
    # axes, so the crop is a cube.
    factor = int(np.ceil(0.02 * groundtruth.shape[0]))
    mult_factor = image.shape[0] / labels.shape[0]
    start = int(np.floor(min(i, j, k) * mult_factor - factor))
    end = int(np.ceil(max(i2, j2, k2) * mult_factor + factor))
    # Clamp the range to the image extent.
    cut = [max(start, 0), min(end, image.shape[0])]
    # Crop the same cube out of the ground truth and the original image.
    crop = (slice(cut[0], cut[1]),) * 3
    cut_GT = np.round(groundtruth[crop])
    cut_img = image[crop]
    np.savez('WHS/Data/train_segments_{}'.format(n), images=cut_img, labels=cut_GT, cut=cut)
    print('Train image', (n+1), 'cut', (cut))