# -*- coding: utf-8 -*-
"""
unet.py

Created on Tue Jul 21 16:11:54 2020

@author: Billy
"""
import tensorflow as tf
from PIL import Image
import numpy as np
import math
import matplotlib.pyplot as plt
import gc
import os
import time
import random

# Raise PIL's decompression-bomb limit so very large histology images can be opened.
Image.MAX_IMAGE_PIXELS = 933120000


class uNet_segmentor:

    # Builds the U-Net model and, if a checkpoint exists, loads the latest trained
    # weights from it (supply the checkpoint file location when initialising).
    def __init__(self, checkpoint_loc, first_conv=8, window=224, outputs=11):
        self.n = n = first_conv
        self.input_size = input_size = (window, window, 3)
        self.outputs = outputs
        inputs = tf.keras.layers.Input(input_size)

        # Encoder: two 3x3 convolutions per level, doubling the filter count before each 2x2 max-pool.
        conv1 = tf.keras.layers.Conv2D(n, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
        conv1 = tf.keras.layers.Conv2D(n, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
        pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv1)

        conv2 = tf.keras.layers.Conv2D(n*2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
        conv2 = tf.keras.layers.Conv2D(n*2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
        pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv2)

        conv3 = tf.keras.layers.Conv2D(n*(2**2), 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
        conv3 = tf.keras.layers.Conv2D(n*(2**2), 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
        drop3 = tf.keras.layers.Dropout(0.5)(conv3)
        pool3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(drop3)

        conv4 = tf.keras.layers.Conv2D(n*(2**3), 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
        conv4 = tf.keras.layers.Conv2D(n*(2**3), 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
        drop4 = tf.keras.layers.Dropout(0.5)(conv4)
        pool4 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(drop4)

        conv_e = tf.keras.layers.Conv2D(n*(2**4), 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
        conv_e = tf.keras.layers.Conv2D(n*(2**4), 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv_e)
        drop_e = tf.keras.layers.Dropout(0.5)(conv_e)
        pool_e = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(drop_e)

        # Bottleneck.
        conv5 = tf.keras.layers.Conv2D(n*(2**5), 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool_e)
        conv5 = tf.keras.layers.Conv2D(n*(2**5), 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
        drop5 = tf.keras.layers.Dropout(0.5)(conv5)

        # Decoder: upsample, then concatenate the matching encoder feature map (skip connection).
        up6_e = tf.keras.layers.Conv2D(n*(2**4), 2, activation='relu', padding='same', kernel_initializer='he_normal')(tf.keras.layers.UpSampling2D(size=(2, 2))(drop5))
        merge6_e = tf.keras.layers.concatenate([drop_e, up6_e], axis=3)
        conv6_e = tf.keras.layers.Conv2D(n*(2**4), 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6_e)
        conv6_e = tf.keras.layers.Conv2D(n*(2**4), 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6_e)

        up6 = tf.keras.layers.Conv2D(n*(2**3), 2, activation='relu', padding='same', kernel_initializer='he_normal')(tf.keras.layers.UpSampling2D(size=(2, 2))(conv6_e))
        merge6 = tf.keras.layers.concatenate([drop4, up6], axis=3)
        conv6 = tf.keras.layers.Conv2D(n*(2**3), 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
        conv6 = tf.keras.layers.Conv2D(n*(2**3), 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)

        up7 = tf.keras.layers.Conv2D(n*(2**2), 2, activation='relu', padding='same', kernel_initializer='he_normal')(tf.keras.layers.UpSampling2D(size=(2, 2))(conv6))
        merge7 = tf.keras.layers.concatenate([conv3, up7], axis=3)
        conv7 = tf.keras.layers.Conv2D(n*(2**2), 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
        conv7 = tf.keras.layers.Conv2D(n*(2**2), 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)

        up8 = tf.keras.layers.Conv2D(n*2, 2, activation='relu', padding='same', kernel_initializer='he_normal')(tf.keras.layers.UpSampling2D(size=(2, 2))(conv7))
        merge8 = tf.keras.layers.concatenate([conv2, up8], axis=3)
        conv8 = tf.keras.layers.Conv2D(n*2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
        conv8 = tf.keras.layers.Conv2D(n*2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)

        up9 = tf.keras.layers.Conv2D(n*2, 2, activation='relu', padding='same', kernel_initializer='he_normal')(tf.keras.layers.UpSampling2D(size=(2, 2))(conv8))
        merge9 = tf.keras.layers.concatenate([conv1, up9], axis=3)
        conv9 = tf.keras.layers.Conv2D(n*2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
        conv9 = tf.keras.layers.Conv2D(n*2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
        # Per-class sigmoid head: one output channel per label; predict() later takes
        # an argmax over these channels.
        conv9 = tf.keras.layers.Conv2D(self.outputs, 3, activation='sigmoid', padding='same', kernel_initializer='he_normal')(conv9)

        model = tf.keras.Model(inputs=inputs, outputs=conv9)

        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])
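
        # With five 2x2 max-pools, the skip-connection concatenations only line up when
        # the window size is a multiple of 2**5 = 32 (224 and 384 both qualify). A guard
        # one could add here, as a minimal sketch:
        #
        #   assert window % 32 == 0, "U-Net input size must be divisible by 32"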

        # If a checkpoint exists, load it so the model starts from the latest trained
        # weights (TensorFlow writes a file named 'checkpoint' alongside the weight files).
        parentdir, file = os.path.split(checkpoint_loc)
        cpkt_save = os.path.join(parentdir, 'checkpoint')
        if os.path.exists(cpkt_save):
            model.load_weights(checkpoint_loc)

        self.model = model


    # Converts an input image into a set of tiles matching the network's input size.
    #
    # For example, if the network input is 4x4x3 (4x4 RGB) and the input image is 7x7x3,
    # the image is placed into an 8x8x3 template (8x8 being the smallest dimensions
    # divisible into 4x4 tiles) and the template is then divided into four 4x4x3 tiles.
    #
    # This is useful because those tiles can be fed directly into the neural network.
    def tile(self, im_loc):
        height, width, channels = self.input_size

        im = np.asarray(Image.open(im_loc))[:, :, 0:3]
        newy = math.ceil(im.shape[0]/height)
        newx = math.ceil(im.shape[1]/width)

        # White template of the smallest tile-divisible size, with the image placed in
        # its top-left corner.
        template = np.ones((newy*height, newx*width, channels))*255
        template[0:im.shape[0], 0:im.shape[1]] = im

        # Split into tiles in row-major order: shape (newy*newx, height, width, channels).
        template = template.reshape(newy, height, newx, width, channels).swapaxes(1, 2).reshape(-1, height, width, channels)

        return template, im.shape[0], im.shape[1]

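    # A worked example of the tiling arithmetic above (hypothetical numbers): with a
    # 4x4 network input and a 7x7 RGB image, newy = newx = ceil(7/4) = 2, so the image
    # sits in an 8x8x3 template and
    #
    #   template.reshape(2, 4, 2, 4, 3).swapaxes(1, 2).reshape(-1, 4, 4, 3)
    #
    # yields four 4x4x3 tiles in row-major order, which is the order predict() assumes
    # when stitching the model outputs back together.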

    # predict() receives an input image file location, opens the image, runs the model
    # over it, and saves the model output.
    #
    # It uses tile() to split the input image into a set of analysable windows, runs
    # the model on each window, then stitches the outputs together and saves them.
    #
    # By default the output is saved in the same folder as the input image, with the
    # same file name prefixed by 'predicted_'.
    def predict(self, im_loc, save_loc=None, show=False):

        if save_loc is None:
            parent_dir, file = os.path.split(im_loc)
            save_loc = os.path.join(parent_dir, 'predicted_'+file)

        # tile() returns the tiles plus the original image height and width.
        tiles, og_height, og_width = self.tile(im_loc)
        print(np.shape(tiles))
        pred_tiles = self.model.predict(np.array(tiles), verbose=1)
        print(np.shape(pred_tiles))

        # Number of tile rows and columns, mirroring the row-major order used by tile().
        split_rows = math.ceil(og_height/self.input_size[0])
        split_cols = math.ceil(og_width/self.input_size[1])

        # Stitch the per-tile predictions back into one padded image, take the most
        # probable class per pixel, then crop away the white padding added by tile().
        pred_tiles = pred_tiles.reshape(split_rows, split_cols, self.input_size[0], self.input_size[1], self.outputs).swapaxes(1, 2).reshape(split_rows*self.input_size[0], split_cols*self.input_size[1], self.outputs)
        pred_tiles = np.argmax(pred_tiles, axis=2)
        pred_tiles = pred_tiles[0:og_height, 0:og_width]

        if show:
            plt.imshow(pred_tiles, cmap='gray', vmax=self.outputs-1, vmin=0)
            plt.show()
        # Save the class map as a greyscale PNG.
        plt.imsave(save_loc, pred_tiles, cmap='gray', vmax=self.outputs-1, vmin=0)
        plt.close()
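
    # A minimal usage sketch (hypothetical paths): given a trained checkpoint, this
    # would write 'predicted_slide.png' next to the input image:
    #
    #   unet = uNet_segmentor('N:\\8_(384, 384, 3)\\cp.ckpt', window=384)
    #   unet.predict('C:\\some_folder\\slide.png')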


    # This method trains the neural network on a set of images labelled with a MATLAB
    # app, but the app itself is not necessary.
    #
    # The MATLAB app is called 'Image Labeler'.
    #
    # It assigns semantic labels to an RGB/greyscale image by annotating pixel regions
    # with colours; each colour corresponds to a number, which corresponds to a pixel
    # label. For each input image one can therefore generate a 'ground truth' pixel map
    # in which every pixel is labelled according to the semantic labels set up in the app.
    #
    # In reality the app is not needed for this function. Just have two folders of
    # images: one with the RGB input images and another with the ground truths, where
    # the semantic labels are distinct integers starting at 0 and incrementing (e.g.
    # 0 = background and 1, 2, ... = tissue classes such as stroma). Each ground truth
    # must have the same size and shape as its image, but training images do not need
    # to match the neural network's input size.
    #
    # This function receives the folder locations of the training images and ground
    # truths and learns from randomly chosen windows within each corresponding pair of
    # images. One can vary how many windows are processed per batch (batch = 32 by
    # default), how many windows are sampled per training image (photo_per_mask = 100
    # by default) and how many times the training set is revisited (epochs = 500 by
    # default).
    def learn_from_matlab(self, image_loc, label_loc, photo_per_mask=100, batch=32, epochs=500, colour=True):

        # Checkpoints are written to a folder named after the architecture (author's drive path).
        checkpoint_save = "N:\\"+str(self.n)+"_"+str(self.input_size)+"\\cp.ckpt"

        assert os.path.exists(image_loc)
        assert os.path.exists(label_loc)

        def gen(photo_per_mask=photo_per_mask):
            # Sort both listings so each image pairs with its label by index.
            image_files = sorted(os.listdir(image_loc))
            label_files = sorted(os.listdir(label_loc))
            gc.collect()

            assert len(label_files) == len(image_files)
            for i in range(len(image_files)):
                if colour:
                    in_ = np.asarray(Image.open(os.path.join(image_loc, image_files[i])).convert('RGB'))
                    label = np.asarray(Image.open(os.path.join(label_loc, label_files[i])).convert('L'))
                    # Pad each channel by half a window (symmetric reflection) so windows
                    # sampled near the border stay in bounds.
                    padded_image_data = [0]*3
                    for j in range(3):
                        padded_image_data[j] = np.pad(in_[:, :, j], (int(self.input_size[0]/2), int(self.input_size[1]/2)), 'symmetric')
                    padded_image_data = np.dstack((padded_image_data[0], padded_image_data[1], padded_image_data[2]))
                    padded_label_data = np.pad(label, (int(self.input_size[0]/2), int(self.input_size[1]/2)), 'symmetric')
                else:
                    in_ = np.asarray(Image.open(os.path.join(image_loc, image_files[i])).convert('L'))
                    label = np.asarray(Image.open(os.path.join(label_loc, label_files[i])).convert('L'))
                    padded_image_data = np.pad(in_, (int(self.input_size[0]/2), int(self.input_size[1]/2)), 'symmetric')
                    padded_label_data = np.pad(label, (int(self.input_size[0]/2), int(self.input_size[1]/2)), 'symmetric')

                for j in range(photo_per_mask):

                    # Pick a random window position within the unpadded image.
                    height, width = label.shape
                    row = random.randint(0, height-1)
                    col = random.randint(0, width-1)

                    # Random rotation by 0, 90 or 180 degrees.
                    rotate = random.randint(0, 2)

                    window = np.rot90(padded_image_data[row:row+self.input_size[0], col:col+self.input_size[1]], rotate)
                    mask = np.asarray(np.rot90(padded_label_data[row:row+self.input_size[0], col:col+self.input_size[1]], rotate))

                    # Random vertical flip, applied identically to window and mask.
                    if random.uniform(0, 1) > 0.5:
                        window = np.flip(window, axis=0)
                        mask = np.flip(mask, axis=0)

                    # One-hot encode the mask: (height, width) -> (height, width, outputs).
                    yield (window, (np.arange(self.outputs) == mask[..., None]).astype(int))
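
        # A worked example of the one-hot encoding above: with outputs = 3 and a mask
        # pixel labelled 2,
        #
        #   (np.arange(3) == 2).astype(int)  ->  array([0, 0, 1])
        #
        # so each (height, width) mask becomes a (height, width, 3) one-hot volume.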

        # Shuffle buffer sized to one full pass over the data. Note the declared shapes
        # assume colour=True; greyscale windows would be (height, width).
        shuffle = photo_per_mask*len(os.listdir(image_loc))
        dataset = tf.data.Dataset.from_generator(gen, (tf.int64, tf.int64), (self.input_size, (self.input_size[0], self.input_size[1], self.outputs)))
        dataset = dataset.shuffle(shuffle).batch(batch)
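
        # On newer TensorFlow releases, where the positional (output_types, output_shapes)
        # form of from_generator is deprecated, an equivalent call would be this sketch:
        #
        #   dataset = tf.data.Dataset.from_generator(
        #       gen,
        #       output_signature=(
        #           tf.TensorSpec(shape=self.input_size, dtype=tf.int64),
        #           tf.TensorSpec(shape=(self.input_size[0], self.input_size[1], self.outputs), dtype=tf.int64)))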

        # Save the weights after every epoch so training can resume from the checkpoint.
        cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save,
                                                         save_weights_only=True,
                                                         verbose=1)
        self.model.fit(dataset, epochs=epochs, callbacks=[cp_callback])


if __name__ == '__main__':
    # Initialise the U-Net with the file location of the model checkpoint.
    unet = uNet_segmentor('N:\\8_(384, 384, 3)\\cp.ckpt', window=384, first_conv=8)

    # Locations of the training images and ground-truth label images.
    image_loc = "N:\\Bill_Mcgough\\Correct Labelling\\Labels\\ImageData"
    label_loc = "N:\\Bill_Mcgough\\Correct Labelling\\Labels\\MatlabLabelData"

    # Train the U-Net on the images and labels above.
    unet.learn_from_matlab(image_loc, label_loc, photo_per_mask=65, batch=10, epochs=5000, colour=True)

    # Run prediction over every image in the folder below.
    predictions = "C:\\Users\\Billy\\Downloads\\Prepared_SVS"
    images = os.listdir(predictions)
    print(images)

    for image in images:
        complete_loc = os.path.join(predictions, image)
        print(complete_loc)
        unet.predict(complete_loc)

    # Time a single prediction.
    before = time.time()
    unet.predict("C:\\Users\\Billy\\Downloads\\Output tester\\14-2302 (1)_0_mag20.png")
    end = time.time()
    print(end-before)