a b/unet.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Tue Jul 21 16:11:54 2020
4
5
@author: Billy
6
"""
7
import tensorflow as tf
8
from PIL import Image
9
import numpy as np
10
import math 
11
import matplotlib.pyplot as plt
12
import gc
13
import os
14
import time
15
import random
16
Image.MAX_IMAGE_PIXELS = 933120000
17
18
class uNet_segmentor:
    """U-Net style semantic segmentor built with tf.keras.

    The network has five 2x2 max-pooling stages, so the square input window
    must be divisible by 32.  The final layer is a sigmoid convolution with
    one channel per class; predictions are reduced to a label map with
    ``argmax`` over that channel axis.
    """

    def __init__(self, checkpoint_loc, first_conv=8, window=224, outputs=11):
        """Build and compile the model, then load weights if a checkpoint exists.

        Args:
            checkpoint_loc: Path of the weights checkpoint (e.g. ``.../cp.ckpt``).
                Weights are loaded only when a file named ``checkpoint`` sits
                in the same directory.
            first_conv: Number of filters in the first convolution; deeper
                stages use multiples of it.
            window: Side length of the square RGB input window.
            outputs: Number of output channels (one per class).

        Raises:
            ValueError: If ``window`` is not a positive multiple of 32.
        """
        # Five pooling stages halve the resolution five times; any other size
        # makes the skip-connection concatenations fail with a shape error.
        if window <= 0 or window % 32 != 0:
            raise ValueError(
                "window must be a positive multiple of 32, got %r" % (window,))

        self.n = n = first_conv
        self.input_size = input_size = (window, window, 3)
        self.outputs = outputs

        layers = tf.keras.layers

        # The helpers return *un-called* layers so that layer construction and
        # call order stay exactly as in the original hand-written network
        # (keeps auto-generated layer names / checkpoints compatible).
        def conv(filters, kernel_size=3, activation='relu'):
            return layers.Conv2D(filters, kernel_size, activation=activation,
                                 padding='same', kernel_initializer='he_normal')

        def pool():
            return layers.MaxPooling2D(pool_size=(2, 2))

        def upsample():
            return layers.UpSampling2D(size=(2, 2))

        inputs = layers.Input(input_size)

        # ---- contracting path -------------------------------------------
        conv1 = conv(n)(inputs)
        conv1 = conv(n)(conv1)
        pool1 = pool()(conv1)

        conv2 = conv(n * 2)(pool1)
        conv2 = conv(n * 2)(conv2)
        pool2 = pool()(conv2)

        conv3 = conv(n * (2 ** 2))(pool2)
        conv3 = conv(n * (2 ** 2))(conv3)
        drop3 = layers.Dropout(0.5)(conv3)
        pool3 = pool()(drop3)

        conv4 = conv(n * (2 ** 3))(pool3)
        conv4 = conv(n * (2 ** 3))(conv4)
        drop4 = layers.Dropout(0.5)(conv4)
        pool4 = pool()(drop4)

        conv_e = conv(n * (2 ** 4))(pool4)
        conv_e = conv(n * (2 ** 4))(conv_e)
        drop_e = layers.Dropout(0.5)(conv_e)
        pool_e = pool()(drop_e)

        # ---- bottleneck --------------------------------------------------
        conv5 = conv(n * (2 ** 5))(pool_e)
        conv5 = conv(n * (2 ** 5))(conv5)
        drop5 = layers.Dropout(0.5)(conv5)

        # ---- expanding path ----------------------------------------------
        up6_e = conv(n * (2 ** 4), 2)(upsample()(drop5))
        merge6_e = layers.concatenate([drop_e, up6_e], axis=3)
        conv6_e = conv(n * (2 ** 4))(merge6_e)
        conv6_e = conv(n * (2 ** 4))(conv6_e)

        up6 = conv(n * (2 ** 3), 2)(upsample()(conv6_e))
        merge6 = layers.concatenate([drop4, up6], axis=3)
        conv6 = conv(n * (2 ** 3))(merge6)
        conv6 = conv(n * (2 ** 3))(conv6)

        up7 = conv(n * (2 ** 2), 2)(upsample()(conv6))
        # The skip connection uses conv3 (pre-dropout), as in the original.
        merge7 = layers.concatenate([conv3, up7], axis=3)
        conv7 = conv(n * (2 ** 2))(merge7)
        conv7 = conv(n * (2 ** 2))(conv7)

        up8 = conv(n * 2, 2)(upsample()(conv7))
        merge8 = layers.concatenate([conv2, up8], axis=3)
        conv8 = conv(n * 2)(merge8)
        conv8 = conv(n * 2)(conv8)

        # The last decoder stage keeps n*2 filters (not n); changing this
        # would invalidate existing checkpoints.
        up9 = conv(n * 2, 2)(upsample()(conv8))
        merge9 = layers.concatenate([conv1, up9], axis=3)
        conv9 = conv(n * 2)(merge9)
        conv9 = conv(n * 2)(conv9)
        conv9 = conv(self.outputs, 3, activation='sigmoid')(conv9)

        model = tf.keras.Model(inputs=inputs, outputs=conv9)
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                      loss='binary_crossentropy', metrics=['accuracy'])

        # Load the latest trained weights when the checkpoint directory holds
        # a 'checkpoint' index file; otherwise keep the fresh initialisation.
        cpkt_save = os.path.join(os.path.dirname(checkpoint_loc), 'checkpoint')
        if os.path.exists(cpkt_save):
            model.load_weights(checkpoint_loc)

        self.model = model

    def _tile_array(self, im):
        """Pad ``im`` with 255 up to a whole number of windows and cut it into tiles.

        Args:
            im: Array indexed as (row, column, channel) whose channel count
                matches ``self.input_size[2]``.

        Returns:
            ``(tiles, rows, cols)`` where ``tiles`` is a float array of shape
            ``(n_tiles, height, width, channels)`` in row-major tile order and
            ``rows``/``cols`` are the original image dimensions.
        """
        height, width, channels = self.input_size
        rows, cols = im.shape[0], im.shape[1]
        tiles_y = math.ceil(rows / height)
        tiles_x = math.ceil(cols / width)

        # White (255) canvas; the image is placed in the top-left corner.
        template = np.full((tiles_y * height, tiles_x * width, channels), 255.0)
        template[0:rows, 0:cols] = im

        tiles = (template
                 .reshape(tiles_y, height, tiles_x, width, channels)
                 .swapaxes(1, 2)
                 .reshape(-1, height, width, channels))
        return tiles, rows, cols

    def _stitch(self, pred_tiles, og_height, og_width):
        """Reassemble per-tile network outputs into one padded-size array.

        Inverse of the tiling done by ``_tile_array``: the result has shape
        ``(tiles_y * height, tiles_x * width, self.outputs)``.
        """
        tile_h, tile_w = self.input_size[0], self.input_size[1]
        tiles_y = math.ceil(og_height / tile_h)
        tiles_x = math.ceil(og_width / tile_w)
        return (pred_tiles
                .reshape(tiles_y, tiles_x, tile_h, tile_w, self.outputs)
                .swapaxes(1, 2)
                .reshape(tiles_y * tile_h, tiles_x * tile_w, self.outputs))

    def _one_hot(self, mask):
        """Expand an integer label mask into ``self.outputs`` one-hot channels."""
        return (np.arange(self.outputs) == mask[..., None]).astype(int)

    def tile(self, im_loc):
        """Open an image and split it into network-sized tiles.

        The image is placed in a white template whose sides are the next
        multiples of the window size, and that template is cut into windows.
        For example, with a 4x4x3 input a 7x7x3 image is inserted into an
        8x8x3 template and divided into four 4x4x3 tiles that can be fed
        directly to the network.

        Args:
            im_loc: Path of the image file.

        Returns:
            ``(tiles, image_height, image_width)``.
        """
        # convert('RGB') matches the training loader and also copes with
        # greyscale / palette / RGBA files (alpha is dropped).
        with Image.open(im_loc) as img:
            im = np.asarray(img.convert('RGB'))
        return self._tile_array(im)

    def predict(self, im_loc, save_loc=None, show=False):
        """Segment one image file and save the predicted label map as a PNG.

        The image is tiled with ``tile``, every tile is passed through the
        network, the outputs are stitched back together and reduced to a
        label map with ``argmax`` over the class channel.  The saved map has
        the padded (whole-tile) size, not the original image size.

        Args:
            im_loc: Path of the input image.
            save_loc: Output path.  Defaults to the input folder with the
                file name prefixed by ``'predicted_'``.
            show: When true the figure is displayed; otherwise it is closed
                after saving.

        Returns:
            The 2-D integer label map that was written to ``save_loc``.
        """
        if save_loc is None:
            parent_dir, file_name = os.path.split(im_loc)
            save_loc = os.path.join(parent_dir, 'predicted_' + file_name)

        tiles, og_height, og_width = self.tile(im_loc)
        print(np.shape(tiles))
        pred_tiles = self.model.predict(tiles, verbose=1)
        print(np.shape(pred_tiles))

        label_map = np.argmax(self._stitch(pred_tiles, og_height, og_width), axis=2)

        # Grey levels are scaled so that class 0 is black and the highest
        # class index is white.
        plt.imshow(label_map, cmap='gray', vmax=self.outputs - 1, vmin=0).write_png(save_loc)
        if show:
            plt.show()
        else:
            plt.close()
        return label_map

    def learn_from_matlab(self, image_loc, label_loc, photo_per_mask=100, batch=32,
                          epochs=500, colour=True, checkpoint_save=None):
        """Train the network on random windows cut from image / ground-truth pairs.

        The method was designed around label images exported from the MATLAB
        'Image Labeller' app, but any two folders work: one with the input
        images and one with ground-truth images in which each pixel value is
        an integer class label starting at 0 (e.g. 0 = background,
        2 = stroma).  An image and its ground truth should have identical
        height and width; they do not need to match the network window size.

        Files of the two folders are paired by sorted file name, so the names
        must sort in corresponding order.  Each pair is symmetric-padded by
        half a window and ``photo_per_mask`` random windows are drawn from it
        with a random rotation and an optional vertical flip.

        Args:
            image_loc: Folder with the training images.
            label_loc: Folder with the ground-truth label images.
            photo_per_mask: Number of random windows drawn per image per epoch.
            batch: Batch size.
            epochs: Number of passes over the generated windows.
            colour: If false, images are converted to greyscale first (the
                grey values are replicated over the three input channels).
            checkpoint_save: Where weights are written after every epoch.
                Defaults to the original location derived from ``first_conv``
                and the input size on drive ``N:``.

        Raises:
            FileNotFoundError: If either folder does not exist.
            ValueError: If the folders hold a different number of files
                (raised when the data generator starts).
        """
        if checkpoint_save is None:
            checkpoint_save = "N:\\" + str(self.n) + "_" + str(self.input_size) + "\\cp.ckpt"

        for folder in (image_loc, label_loc):
            if not os.path.exists(folder):
                raise FileNotFoundError("Training folder not found: %s" % (folder,))

        win_h, win_w = self.input_size[0], self.input_size[1]
        half_h, half_w = win_h // 2, win_w // 2

        def gen(photo_per_mask=photo_per_mask):
            # os.listdir order is arbitrary; sort so image i pairs with label i.
            image_files = sorted(os.listdir(image_loc))
            label_files = sorted(os.listdir(label_loc))
            gc.collect()

            if len(label_files) != len(image_files):
                raise ValueError(
                    "Found %d images but %d label files"
                    % (len(image_files), len(label_files)))

            for image_file, label_file in zip(image_files, label_files):
                with Image.open(os.path.join(image_loc, image_file)) as img:
                    if colour:
                        in_ = np.asarray(img.convert('RGB'))
                    else:
                        # The model input is always 3-channel, so greyscale
                        # data is replicated across the channels.
                        in_ = np.asarray(img.convert('L').convert('RGB'))
                with Image.open(os.path.join(label_loc, label_file)) as img:
                    label = np.asarray(img.convert('L'))

                # Half a window of mirrored context on every side, so a window
                # starting at any label pixel stays inside the padded arrays.
                padded_image_data = np.pad(
                    in_, ((half_h, half_h), (half_w, half_w), (0, 0)), 'symmetric')
                padded_label_data = np.pad(
                    label, ((half_h, half_h), (half_w, half_w)), 'symmetric')

                height, width = label.shape
                for _ in range(photo_per_mask):
                    row = random.randint(0, height - 1)
                    col = random.randint(0, width - 1)

                    # NOTE(review): randint is inclusive, so this draws 0, 90
                    # or 180 degrees and never 270 - confirm that is intended.
                    rotate = random.randint(0, 2)

                    window = np.rot90(
                        padded_image_data[row:row + win_h, col:col + win_w], rotate)
                    mask = np.rot90(
                        padded_label_data[row:row + win_h, col:col + win_w], rotate)

                    if random.uniform(0, 1) > 0.5:
                        window = np.flip(window, axis=0)
                        mask = np.flip(mask, axis=0)

                    yield (window, self._one_hot(mask))

        # Buffer large enough to shuffle one full epoch of windows.
        shuffle = photo_per_mask * len(os.listdir(image_loc))
        dataset = tf.data.Dataset.from_generator(
            gen,
            (tf.int64, tf.int64),
            (self.input_size, (win_h, win_w, self.outputs)))
        dataset = dataset.shuffle(shuffle).batch(batch)

        cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save,
                                                         save_weights_only=True,
                                                         verbose=1)
        self.model.fit(dataset, epochs=epochs, callbacks=[cp_callback])
223
224
        
225
if __name__ == '__main__':
    # Build the U-Net; weights are restored from this checkpoint when present.
    unet = uNet_segmentor(
        'N:\\8_(384, 384, 3)\\cp.ckpt',
        window=384,
        first_conv=8,
    )

    # Folders holding the training images and their ground-truth label maps.
    image_loc = "N:\\Bill_Mcgough\\Correct Labelling\\Labels\\ImageData"
    label_loc = "N:\\Bill_Mcgough\\Correct Labelling\\Labels\\MatlabLabelData"

    # Train on random windows drawn from the image / label pairs.
    unet.learn_from_matlab(
        image_loc,
        label_loc,
        photo_per_mask=65,
        batch=10,
        epochs=5000,
        colour=True,
    )

    # Segment every file found in the prediction folder.
    predictions = "C:\\Users\\Billy\\Downloads\\Prepared_SVS"
    images = os.listdir(predictions)
    print(images)

    for image in images:
        complete_loc = os.path.join(predictions, image)
        print(complete_loc)
        unet.predict(complete_loc)

    # Time a single prediction on a reference image.
    before = time.time()
    unet.predict("C:\\Users\\Billy\\Downloads\\Output tester\\14-2302 (1)_0_mag20.png")
    end = time.time()
    print(end - before)