[bb7f56]: / utils / preprocessing / generate_patches.py

Download this file

124 lines (91 with data), 4.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import numpy as np
import SimpleITK as sitk
import random
def generate_negative_patches(studies):
art_arr = sitk.GetArrayFromImage(studies[0])
ven_arr = sitk.GetArrayFromImage(studies[1])
noncon_arr = sitk.GetArrayFromImage(studies[2])
mask_arr = sitk.GetArrayFromImage(studies[3])
tot_num_neg = 1
neg_patch = 0
#### Choose where to randomly select a negative patch using
#### Gaussian dist. with peak and std according to
#### blocks of 32 voxels. For example, 16 blocks in 512 voxels.
rand_xy = np.around(np.random.normal(4,2,size=(2,1000)))
rand_z = np.around(np.random.normal(10,4,1000))
rand_xyz = np.vstack((rand_xy, rand_z))
i = 0
chosen_xyz = list()
output_patches = list()
##Create Random Negative Patches
while neg_patch<tot_num_neg:
bad_patch_flag=False
i = i+1
print (i)
## If no patch is found after 80 attempts, abort
if i == 80:
return
## Randomly selected blocks of 32 converted to voxels
x_start = int(rand_xyz[0][i]*32)
y_start = int(rand_xyz[1][i]*32)
z_start = int(rand_xyz[2][i]*32)
# Choose 256^3 patch
x_end = x_start+256
y_end = y_start+256
z_end = z_start+256
print ("Random patch: ")
print (x_start,x_end,y_start,y_end,z_start,z_end)
patch_noncon = noncon_arr[z_start:z_end,y_start:y_end,x_start:x_end]
patch_ven = ven_arr[z_start:z_end,y_start:y_end,x_start:x_end]
patch_art = art_arr[z_start:z_end,y_start:y_end,x_start:x_end]
patch_mask = mask_arr[z_start:z_end,y_start:y_end,x_start:x_end]
if [rand_xyz[0][i],rand_xyz[1][i],rand_xyz[2][i]] in chosen_xyz or rand_xyz[0][i]<0 \
or rand_xyz[0][i]>11 or rand_xyz[1][i]>11 or rand_xyz[1][i]<0 or rand_xyz[2][i]>15 or rand_xyz[2][i]<0:
print ("Exclude Random Patch --> Outside FOV")
continue
##Determine if patch has a low mean value, std or excessive 0 values
for arr_patch in [patch_noncon,patch_ven,patch_art]:
if len(np.where(arr_patch==0)[0]) > 0.3*(256**3) or arr_patch.mean()<-450:
print ("Try new patch, likely off target given stats")
print (len(np.where(arr_patch==0)[0])/(256**3))
print (arr_patch.mean())
print (arr_patch.std())
bad_patch_flag=True
if bad_patch_flag == True or patch_mask.mean()>0:
print ("Rejected 'cause of flag or patch mask mean greater than 0")
print (patch_mask.mean()>0)
continue
print ([x_start,x_end,y_start,y_end,z_start,z_end])
output_patches.append([x_start,x_end,y_start,y_end,z_start,z_end])
chosen_xyz.append([rand_xyz[0][i],rand_xyz[1][i],rand_xyz[2][i]])
print ("Identified a neg_patch number : "+str(neg_patch))
neg_patch = neg_patch + 1
#sitk.WriteImage(sitk.GetImageFromArray(arr_patch),os.path.join(OUT_DIR,str(i)+".nii.gz"))
return output_patches
def generate_positive_patches(studies,patch_size,random_center_displacement):
q = random_center_displacement
pos_patch = list()
mask_arr = sitk.GetArrayFromImage(studies[3])
where = np.where(mask_arr==1)
z_mid = int(where[0].mean())
y_mid = int(where[1].mean())
x_mid = int(where[2].mean())
print (z_mid,y_mid,x_mid)
z_start = random.randint(-(q),q)+z_mid-(patch_size[0]/2)
y_start = random.randint(-(q),q)+y_mid-(patch_size[1]/2)
x_start = random.randint(-(q),q)+x_mid-(patch_size[2]/2)
z_end = z_start+patch_size[0]
y_end = y_start+patch_size[1]
x_end = x_start+patch_size[2]
numbers = [ int(x) for x in [x_start,x_end,y_start,y_end,z_start,z_end] ]
pos_patch.append(numbers)
return pos_patch
def test_onbody_patch(load_path):
try:
arr = sitk.GetArrayFromImage(sitk.ReadImage(load_path))
print (load_path)
print (np.mean(arr))
# print (np.max(arr))
# print (np.min(arr))
except Exception as e:
print ("Could not convert because "+str(e))