Diff of /procedures/trainer.py [000000] .. [f84ece]

Switch to unified view

a b/procedures/trainer.py
1
# MIT License
2
# 
3
# Copyright (c) 2019 Yisroel Mirsky
4
# 
5
# Permission is hereby granted, free of charge, to any person obtaining a copy
6
# of this software and associated documentation files (the "Software"), to deal
7
# in the Software without restriction, including without limitation the rights
8
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
# copies of the Software, and to permit persons to whom the Software is
10
# furnished to do so, subject to the following conditions:
11
# 
12
# The above copyright notice and this permission notice shall be included in all
13
# copies or substantial portions of the Software.
14
# 
15
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
# SOFTWARE.
22
23
from __future__ import print_function, division
24
from config import *  # user configuration in config.py
25
import os
26
# Ask CUDA to enumerate devices in PCI bus order; this is set before the
# Keras/TensorFlow imports further down in this file.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
27
# Restrict the visible CUDA devices to the value of config['gpus']
# (config comes from the star-import of config.py above).
os.environ["CUDA_VISIBLE_DEVICES"] = config['gpus']
28
29
from utils.dataloader import DataLoader
30
from keras.layers import Input, Dropout, Concatenate, Cropping3D
31
from keras.layers import BatchNormalization
32
from keras.layers.advanced_activations import LeakyReLU
33
from keras.layers.convolutional import UpSampling3D, Conv3D
34
from keras.models import Model
35
from keras.optimizers import Adam
36
import matplotlib.pyplot as plt
37
import datetime
38
import numpy as np
39
40
import tensorflow as tf
41
import keras.backend.tensorflow_backend as ktf
42
43
44
def get_session():
    """Return a new TensorFlow session with GPU ``allow_growth`` enabled."""
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(allow_growth=True),
    )
    return tf.Session(config=session_config)
47
48
49
# Install the session returned by get_session() as the one used by the Keras
# TensorFlow backend.
ktf.set_session(get_session())
50
51
52
class Trainer:
    """Build and train a conditional 3D GAN (U-Net generator + PatchGAN discriminator).

    ``isInjector`` selects which sample set and which model directory from
    ``config`` the instance works with.  The generator maps a conditioning
    volume B to a volume A; the discriminator scores (A, B) pairs patch-wise.
    """

    def __init__(self, isInjector=True):
        """Create the data loader and compile the discriminator and combined models.

        Args:
            isInjector: When true, use ``config['unhealthy_samples']`` and
                ``config['modelpath_inject']``; otherwise use
                ``config['healthy_samples']`` and ``config['modelpath_remove']``.
        """
        self.isInjector = isInjector

        # Input shape: cube_shape is indexed as (depth, rows, cols).
        cube_shape = config['cube_shape']
        self.img_rows = cube_shape[1]
        self.img_cols = cube_shape[2]
        self.img_depth = cube_shape[0]
        self.channels = 1
        self.num_classes = 5
        self.img_shape = (self.img_rows, self.img_cols, self.img_depth, self.channels)

        # Configure data loader
        if self.isInjector:
            self.dataset_path = config['unhealthy_samples']
            self.modelpath = config['modelpath_inject']
        else:
            self.dataset_path = config['healthy_samples']
            self.modelpath = config['modelpath_remove']

        self.dataloader = DataLoader(dataset_path=self.dataset_path, normdata_path=self.modelpath,
                                     img_res=(self.img_rows, self.img_cols, self.img_depth))

        # Output shape of D (PatchGAN), computed per axis so that non-cubic
        # inputs and extents that are not multiples of 16 get the right shape.
        self.disc_patch = self._patch_shape(self.img_shape[:3])

        # Number of filters in the first layer of G and D
        self.gf = 100
        self.df = 100

        # NOTE(review): the discriminator's learning rate is 200x smaller than
        # the combined (generator) model's; values are kept exactly as they
        # were -- confirm this imbalance is intentional.
        combined_optimizer = Adam(0.0002, 0.5)
        discriminator_optimizer = Adam(0.000001, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.summary()
        self.discriminator.compile(loss='mse',
                                   optimizer=discriminator_optimizer,
                                   metrics=['accuracy'])

        # -------------------------
        # Construct Computational
        #   Graph of Generator
        # -------------------------

        # Build the generator
        self.generator = self.build_generator()
        self.generator.summary()

        # Input images and their conditioning images
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # By conditioning on B generate a fake version of A
        fake_A = self.generator([img_B])

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # Discriminator determines validity of translated images / condition pairs
        valid = self.discriminator([fake_A, img_B])

        # Adversarial (mse) and reconstruction (mae) losses, weighted 1 : 100.
        self.combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
        self.combined.compile(loss=['mse', 'mae'],
                              loss_weights=[1, 100],
                              optimizer=combined_optimizer)

    @staticmethod
    def _patch_shape(spatial_dims, n_downsamples=4):
        """Return the discriminator output shape for the given spatial extents.

        Each stride-2 convolution with 'same' padding maps a length ``n`` to
        ``ceil(n / 2)``, so ``n_downsamples`` of them yield
        ``ceil(n / 2 ** n_downsamples)`` along every axis.  The default of 4
        matches the four ``d_layer`` calls in ``build_discriminator``.

        Args:
            spatial_dims: Iterable of the spatial input extents.
            n_downsamples: Number of stacked stride-2 convolutions.

        Returns:
            Tuple of the per-axis patch extents followed by a single channel.
        """
        factor = 2 ** n_downsamples
        # -(-a // b) is integer ceiling division.
        return tuple(-(-int(dim) // factor) for dim in spatial_dims) + (1,)

    def build_generator(self):
        """U-Net Generator"""

        def crop_amounts(target, refer, axis):
            """Split the size surplus of `target` over `refer` along `axis` into (before, after)."""
            surplus = (target.get_shape()[axis] - refer.get_shape()[axis]).value
            assert surplus >= 0, "upsampled tensor is smaller than its skip connection"
            before = surplus // 2
            # An odd surplus puts the extra voxel on the trailing side.
            return before, surplus - before

        def get_crop_shape(target, refer):
            """Return the Cropping3D amounts that shrink `target` to the size of `refer`."""
            cd = crop_amounts(target, refer, 3)  # depth, the 4th dimension
            cw = crop_amounts(target, refer, 2)  # width, the 3rd dimension
            ch = crop_amounts(target, refer, 1)  # height, the 2nd dimension
            return ch, cw, cd

        def conv3d(layer_input, filters, f_size=4, bn=True):
            """Layers used during downsampling"""
            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        def deconv3d(layer_input, skip_input, filters, f_size=4, dropout_rate=0.5):
            """Layers used during upsampling"""
            u = UpSampling3D(size=2)(layer_input)
            u = Conv3D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)

            # Upsampling can overshoot the skip tensor when an extent was odd
            # on the way down, so crop before concatenating.
            ch, cw, cd = get_crop_shape(u, skip_input)
            cropped = Cropping3D(cropping=(ch, cw, cd), data_format="channels_last")(u)
            return Concatenate()([cropped, skip_input])

        # Image input
        d0 = Input(shape=self.img_shape, name="input_image")

        # Downsampling
        d1 = conv3d(d0, self.gf, bn=False)
        d2 = conv3d(d1, self.gf * 2)
        d3 = conv3d(d2, self.gf * 4)
        d4 = conv3d(d3, self.gf * 8)
        d5 = conv3d(d4, self.gf * 8)

        # Upsampling with skip connections
        u3 = deconv3d(d5, d4, self.gf * 8)
        u4 = deconv3d(u3, d3, self.gf * 4)
        u5 = deconv3d(u4, d2, self.gf * 2)
        u6 = deconv3d(u5, d1, self.gf)

        u7 = UpSampling3D(size=2)(u6)
        output_img = Conv3D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)

        return Model(inputs=[d0], outputs=[output_img])

    def build_discriminator(self):
        """Build the patch discriminator that scores (image, condition) pairs."""

        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer"""
            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Concatenate image and conditioning image by channels to produce input
        model_input = Concatenate(axis=-1)([img_A, img_B])

        # Four stride-2 layers; keep in sync with _patch_shape's n_downsamples.
        d1 = d_layer(model_input, self.df, bn=False)
        d2 = d_layer(d1, self.df * 2)
        d3 = d_layer(d2, self.df * 4)
        d4 = d_layer(d3, self.df * 8)

        validity = Conv3D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model([img_A, img_B], validity)

    def _save_models(self):
        """Write the generator and discriminator to HDF5 files under ``self.modelpath``."""
        print("Saving Models...")
        os.makedirs(self.modelpath, exist_ok=True)
        self.generator.save(os.path.join(self.modelpath, "G_model.h5"))
        self.discriminator.save(os.path.join(self.modelpath, "D_model.h5"))

    def train(self, epochs, batch_size=1, sample_interval=50):
        """Alternate discriminator and generator updates over the data loader.

        Models are saved at the end of every epoch, including the last one.

        Args:
            epochs: Number of passes over ``self.dataloader.load_batch``.
            batch_size: Batch size requested from the data loader.
            sample_interval: A progress image is written whenever the batch
                index is a multiple of this value.
        """
        start_time = datetime.datetime.now()

        # Adversarial loss ground truths.  Real pairs are labelled 0 and
        # generated pairs 1; the generator is trained towards the same `valid`
        # target, so the labelling is self-consistent.
        valid = np.zeros((batch_size,) + self.disc_patch)
        fake = np.ones((batch_size,) + self.disc_patch)

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(self.dataloader.load_batch(batch_size)):
                # ---------------------
                #  Train Discriminator
                # ---------------------
                # Condition on B and generate a translated version
                fake_A = self.generator.predict([imgs_B])

                # Train the discriminator (original images = real / generated = fake)
                d_loss_real = self.discriminator.train_on_batch([imgs_A, imgs_B], valid)
                d_loss_fake = self.discriminator.train_on_batch([fake_A, imgs_B], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # -----------------
                #  Train Generator
                # -----------------
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])
                elapsed_time = datetime.datetime.now() - start_time

                # Report the progress
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f] time: %s" % (
                    epoch, epochs,
                    batch_i, self.dataloader.n_batches,
                    d_loss[0], 100 * d_loss[1],
                    g_loss[0],
                    elapsed_time))

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.show_progress(epoch, batch_i)

            # Save after the epoch's updates so the final epoch is persisted too.
            self._save_models()

    def show_progress(self, epoch, batch_i):
        """Save a 3x3 grid of condition / generated / original mid-depth slices as a PNG.

        Args:
            epoch: Epoch index, used in the output file name.
            batch_i: Batch index, used in the output file name.
        """
        filename = "%d_%d.png" % (epoch, batch_i)
        subdir = "injector" if self.isInjector else "remover"
        savepath = os.path.join(config['progress'], subdir)
        os.makedirs(savepath, exist_ok=True)
        r, c = 3, 3

        imgs_A, imgs_B = self.dataloader.load_data(batch_size=c, is_testing=True)
        fake_A = self.generator.predict([imgs_B])

        # One row of c samples per title below.
        gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])

        # Rescale images 0 - 1 (presumably the data is normalised to [-1, 1],
        # matching the generator's tanh output -- confirm against DataLoader).
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Condition', 'Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        mid_slice = self.img_depth // 2
        # NOTE(review): volumes are viewed as (depth, rows, cols) here while
        # img_shape is (rows, cols, depth, channels); identical for cubic
        # inputs -- confirm the intended axis order for non-cubic ones.
        view_shape = (self.img_depth, self.img_rows, self.img_cols)
        for i in range(r):
            for j in range(c):
                volume = gen_imgs[i * c + j].reshape(view_shape)
                axs[i, j].imshow(volume[mid_slice, :, :])
                axs[i, j].set_title(titles[i])
                axs[i, j].axis('off')
        fig.savefig(os.path.join(savepath, filename))
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)
289