procedures/trainer.py

# MIT License
#
# Copyright (c) 2019 Yisroel Mirsky
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import print_function, division
from config import *  # user configuration in config.py
import os

# Select the visible GPUs before TensorFlow/Keras initialize CUDA
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = config['gpus']

from utils.dataloader import DataLoader
from keras.layers import Input, Dropout, Concatenate, Cropping3D, BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling3D, Conv3D
from keras.models import Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import datetime
import numpy as np

import tensorflow as tf
import keras.backend.tensorflow_backend as ktf


def get_session():
    # Allocate GPU memory on demand rather than reserving it all up front
    gpu_options = tf.GPUOptions(allow_growth=True)
    return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))


ktf.set_session(get_session())


class Trainer:
    def __init__(self, isInjector=True):
        self.isInjector = isInjector
        # Input shape: config['cube_shape'] is (depth, rows, cols)
        cube_shape = config['cube_shape']
        self.img_rows = cube_shape[1]
        self.img_cols = cube_shape[2]
        self.img_depth = cube_shape[0]
        self.channels = 1
        self.num_classes = 5  # note: not used elsewhere in this file
        self.img_shape = (self.img_rows, self.img_cols, self.img_depth, self.channels)

        # Configure the data loader
        if self.isInjector:
            self.dataset_path = config['unhealthy_samples']
            self.modelpath = config['modelpath_inject']
        else:
            self.dataset_path = config['healthy_samples']
            self.modelpath = config['modelpath_remove']

        self.dataloader = DataLoader(dataset_path=self.dataset_path, normdata_path=self.modelpath,
                                     img_res=(self.img_rows, self.img_cols, self.img_depth))

        # Calculate the output shape of D (PatchGAN): D halves the resolution four times
        patch = int(self.img_rows / 2 ** 4)
        self.disc_patch = (patch, patch, patch, 1)
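        # For example, with 32x32x32 cubes in config['cube_shape'] (an
        # assumption; the actual shape is set in config.py), patch = 32 // 2**4
        # = 2, so the discriminator emits a 2x2x2x1 grid of real/fake scores
        # rather than a single scalar.
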
        # Number of filters in the first layer of G and D
        self.gf = 100
        self.df = 100

        optimizer = Adam(0.0002, 0.5)
        optimizer_D = Adam(0.000001, 0.5)  # the discriminator trains with a much smaller learning rate

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.summary()
        self.discriminator.compile(loss='mse',
                                   optimizer=optimizer_D,
                                   metrics=['accuracy'])

        # -------------------------
        #  Construct Computational
        #    Graph of Generator
        # -------------------------

        # Build the generator
        self.generator = self.build_generator()
        self.generator.summary()

        # Input images and their conditioning images
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # By conditioning on B, generate a fake version of A
        fake_A = self.generator([img_B])

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator determines the validity of translated image / condition pairs
        valid = self.discriminator([fake_A, img_B])

        self.combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
        self.combined.compile(loss=['mse', 'mae'],
                              loss_weights=[1, 100],
                              optimizer=optimizer)
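        # The combined objective is pix2pix-style: a least-squares adversarial
        # term (MSE on the patch validity map) plus an L1 reconstruction term
        # ('mae' against imgs_A) weighted 100:1, the same lambda as in pix2pix.
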
    def build_generator(self):
        """U-Net Generator"""

        def get_crop_shape(target, refer):
            # depth, the 4th dimension
            cd = (target.get_shape()[3] - refer.get_shape()[3]).value
            assert (cd >= 0)
            if cd % 2 != 0:
                cd1, cd2 = int(cd / 2), int(cd / 2) + 1
            else:
                cd1, cd2 = int(cd / 2), int(cd / 2)
            # width, the 3rd dimension
            cw = (target.get_shape()[2] - refer.get_shape()[2]).value
            assert (cw >= 0)
            if cw % 2 != 0:
                cw1, cw2 = int(cw / 2), int(cw / 2) + 1
            else:
                cw1, cw2 = int(cw / 2), int(cw / 2)
            # height, the 2nd dimension
            ch = (target.get_shape()[1] - refer.get_shape()[1]).value
            assert (ch >= 0)
            if ch % 2 != 0:
                ch1, ch2 = int(ch / 2), int(ch / 2) + 1
            else:
                ch1, ch2 = int(ch / 2), int(ch / 2)

            return (ch1, ch2), (cw1, cw2), (cd1, cd2)
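        # get_crop_shape computes per-side crop sizes so an upsampled tensor
        # can be trimmed to match its skip connection. For example, if the
        # upsampled depth is 9 and the skip depth is 8, cd = 1, giving (0, 1).
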
        def conv3d(layer_input, filters, f_size=4, bn=True):
            """Layers used during downsampling"""
            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        def deconv3d(layer_input, skip_input, filters, f_size=4, dropout_rate=0.5):
            """Layers used during upsampling"""
            u = UpSampling3D(size=2)(layer_input)
            u = Conv3D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)

            # Crop the upsampled tensor to the skip connection's spatial size
            # before concatenating
            ch, cw, cd = get_crop_shape(u, skip_input)
            u_cropped = Cropping3D(cropping=(ch, cw, cd), data_format="channels_last")(u)
            u = Concatenate()([u_cropped, skip_input])
            return u

        # Image input
        d0 = Input(shape=self.img_shape, name="input_image")

        # Downsampling
        d1 = conv3d(d0, self.gf, bn=False)
        d2 = conv3d(d1, self.gf * 2)
        d3 = conv3d(d2, self.gf * 4)
        d4 = conv3d(d3, self.gf * 8)
        d5 = conv3d(d4, self.gf * 8)

        # Upsampling with skip connections
        u3 = deconv3d(d5, d4, self.gf * 8)
        u4 = deconv3d(u3, d3, self.gf * 4)
        u5 = deconv3d(u4, d2, self.gf * 2)
        u6 = deconv3d(u5, d1, self.gf)

        u7 = UpSampling3D(size=2)(u6)
        # tanh output matches inputs normalized to [-1, 1]
        # (hence the 0.5 * x + 0.5 rescale in show_progress)
        output_img = Conv3D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)

        return Model(inputs=[d0], outputs=[output_img])

    def build_discriminator(self):

        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer"""
            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Concatenate image and conditioning image by channels to produce the input
        model_input = Concatenate(axis=-1)([img_A, img_B])

        d1 = d_layer(model_input, self.df, bn=False)
        d2 = d_layer(d1, self.df * 2)
        d3 = d_layer(d2, self.df * 4)
        d4 = d_layer(d3, self.df * 8)
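
        # Each voxel of the final map scores one receptive-field patch of the
        # (image, condition) pair as real or fake (a conditional PatchGAN)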
        validity = Conv3D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model([img_A, img_B], validity)

    def train(self, epochs, batch_size=1, sample_interval=50):
        start_time = datetime.datetime.now()
        # Adversarial loss ground truths
        valid = np.zeros((batch_size,) + self.disc_patch)
        fake = np.ones((batch_size,) + self.disc_patch)
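        # Note the inverted label convention: real pairs are targeted with 0
        # and fakes with 1, the reverse of the usual pix2pix setup. It remains
        # self-consistent because the combined model below also uses `valid`
        # (zeros) as the generator's adversarial target.
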
        for epoch in range(epochs):
            # Save the models at the start of every epoch after the first
            if epoch > 0:
                print("Saving Models...")
                self.generator.save(os.path.join(self.modelpath, "G_model.h5"))  # creates an HDF5 file
                self.discriminator.save(os.path.join(self.modelpath, "D_model.h5"))  # creates an HDF5 file

            for batch_i, (imgs_A, imgs_B) in enumerate(self.dataloader.load_batch(batch_size)):
                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Condition on B and generate a translated version
                fake_A = self.generator.predict([imgs_B])

                # Train the discriminator (original images = real / generated = fake)
                d_loss_real = self.discriminator.train_on_batch([imgs_A, imgs_B], valid)
                d_loss_fake = self.discriminator.train_on_batch([fake_A, imgs_B], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # -----------------
                #  Train Generator
                # -----------------
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])

                elapsed_time = datetime.datetime.now() - start_time
                # Print the progress
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f] time: %s"
                      % (epoch, epochs, batch_i, self.dataloader.n_batches,
                         d_loss[0], 100 * d_loss[1], g_loss[0], elapsed_time))

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.show_progress(epoch, batch_i)

    def show_progress(self, epoch, batch_i):
        filename = "%d_%d.png" % (epoch, batch_i)
        if self.isInjector:
            savepath = os.path.join(config['progress'], "injector")
        else:
            savepath = os.path.join(config['progress'], "remover")
        os.makedirs(savepath, exist_ok=True)
        r, c = 3, 3

        imgs_A, imgs_B = self.dataloader.load_data(batch_size=3, is_testing=True)
        fake_A = self.generator.predict([imgs_B])

        gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])

        # Rescale images from [-1, 1] to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Condition', 'Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                # Show the middle slice of each cube along the depth axis
                cube = gen_imgs[cnt].reshape((self.img_depth, self.img_rows, self.img_cols))
                axs[i, j].imshow(cube[int(self.img_depth / 2), :, :])
                axs[i, j].set_title(titles[i])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(os.path.join(savepath, filename))
        plt.close()
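

# Hypothetical usage sketch: it assumes config.py provides the keys read
# above, and the epochs, batch_size and sample_interval values here are
# illustrative, not the authors' settings.
if __name__ == "__main__":
    trainer = Trainer(isInjector=True)  # isInjector=False trains the remover GAN
    trainer.train(epochs=50, batch_size=32, sample_interval=200)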