--- /dev/null
+++ b/procedures/trainer.py
@@ -0,0 +1,310 @@
+# MIT License
+# 
+# Copyright (c) 2019 Yisroel Mirsky
+# 
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+# 
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+# 
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from __future__ import print_function, division
+from config import *  # user configuration in config.py
+import os
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = config['gpus']
+
+from utils.dataloader import DataLoader
+from keras.layers import Input, Dropout, Concatenate, Cropping3D
+from keras.layers import BatchNormalization
+from keras.layers.advanced_activations import LeakyReLU
+from keras.layers.convolutional import UpSampling3D, Conv3D
+from keras.models import Model
+from keras.optimizers import Adam
+import matplotlib.pyplot as plt
+import datetime
+import numpy as np
+
+import tensorflow as tf
+import keras.backend.tensorflow_backend as ktf
+
+
+def get_session():
+    gpu_options = tf.GPUOptions(allow_growth=True)
+    return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
+
+
+ktf.set_session(get_session())
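+# Note: this session setup targets TensorFlow 1.x with standalone Keras. A
+# rough TF 2.x equivalent (an untested sketch) would be:
+#
+#     for gpu in tf.config.list_physical_devices('GPU'):
+#         tf.config.experimental.set_memory_growth(gpu, True)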
+
+
+class Trainer:
+    def __init__(self, isInjector=True):
+        self.isInjector = isInjector
+        # Input shape
+        cube_shape = config['cube_shape']  # (depth, rows, cols)
+        self.img_rows = cube_shape[1]
+        self.img_cols = cube_shape[2]
+        self.img_depth = cube_shape[0]
+        self.channels = 1
+        self.img_shape = (self.img_rows, self.img_cols, self.img_depth, self.channels)
+
+        # Configure data loader
+        if self.isInjector:
+            self.dataset_path = config['unhealthy_samples']
+            self.modelpath = config['modelpath_inject']
+        else:
+            self.dataset_path = config['healthy_samples']
+            self.modelpath = config['modelpath_remove']
+
+        self.dataloader = DataLoader(dataset_path=self.dataset_path, normdata_path=self.modelpath,
+                                     img_res=(self.img_rows, self.img_cols, self.img_depth))
+
+        # Calculate the output shape of D (PatchGAN): the four stride-2 conv
+        # layers in the discriminator downsample each spatial dim by 2**4
+        patch = self.img_rows // 2 ** 4
+        self.disc_patch = (patch, patch, patch, 1)
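+        # For example, with a hypothetical 32x32x32 input cube this gives
+        # patch = 32 // 16 = 2, i.e. D emits a 2x2x2x1 grid of scores.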
+
+        # Number of filters in the first layer of G and D
+        self.gf = 100
+        self.df = 100
+
+        optimizer = Adam(0.0002, 0.5)      # lr, beta_1 (combined model)
+        optimizer_D = Adam(0.000001, 0.5)  # much smaller lr for the discriminator
+
+        # Build and compile the discriminator
+        self.discriminator = self.build_discriminator()
+        self.discriminator.summary()
+        self.discriminator.compile(loss='mse',
+                                   optimizer=optimizer_D,
+                                   metrics=['accuracy'])
+
+        # -------------------------
+        # Construct Computational
+        #   Graph of Generator
+        # -------------------------
+
+        # Build the generator
+        self.generator = self.build_generator()
+        self.generator.summary()
+
+        # Input images and their conditioning images
+        img_A = Input(shape=self.img_shape)
+        img_B = Input(shape=self.img_shape)
+
+        # Condition on B to generate a fake (translated) version of A
+        fake_A = self.generator([img_B])
+
+        # For the combined model we will only train the generator
+        self.discriminator.trainable = False
+
+        # The discriminator determines the validity of translated image / condition pairs
+        valid = self.discriminator([fake_A, img_B])
+
+        # Note: img_A is declared as an input but is not consumed by the graph;
+        # the real image is instead supplied as the L1 target in train_on_batch.
+        self.combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
+        self.combined.compile(loss=['mse', 'mae'],
+                              loss_weights=[1, 100],
+                              optimizer=optimizer)
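+        # This is the standard pix2pix objective: an adversarial least-squares
+        # term plus a heavily weighted L1 reconstruction term,
+        # L_G = L_adv + 100 * L1(fake_A, img_A).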
+
+    def build_generator(self):
+        """U-Net Generator"""
+
+        def get_crop_shape(target, refer):
+            """Per-axis (before, after) crop sizes that shrink `target` to `refer`."""
+
+            # depth, the 4th dimension
+            cd = (target.get_shape()[3] - refer.get_shape()[3]).value
+            assert cd >= 0
+            cd1, cd2 = cd // 2, cd - cd // 2
+            # width, the 3rd dimension
+            cw = (target.get_shape()[2] - refer.get_shape()[2]).value
+            assert cw >= 0
+            cw1, cw2 = cw // 2, cw - cw // 2
+            # height, the 2nd dimension
+            ch = (target.get_shape()[1] - refer.get_shape()[1]).value
+            assert ch >= 0
+            ch1, ch2 = ch // 2, ch - ch // 2
+
+            return (ch1, ch2), (cw1, cw2), (cd1, cd2)
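+        # For example (hypothetical shapes): cropping a (None, 9, 9, 9, C)
+        # tensor down to (None, 8, 8, 8, C) yields ((0, 1), (0, 1), (0, 1)).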
+
+        def conv3d(layer_input, filters, f_size=4, bn=True):
+            """Layers used during downsampling"""
+            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
+            d = LeakyReLU(alpha=0.2)(d)
+            if bn:
+                d = BatchNormalization(momentum=0.8)(d)
+            return d
+
+        def deconv3d(layer_input, skip_input, filters, f_size=4, dropout_rate=0.5):
+            """Layers used during upsampling"""
+            u = UpSampling3D(size=2)(layer_input)
+            u = Conv3D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
+            if dropout_rate:
+                u = Dropout(dropout_rate)(u)
+            u = BatchNormalization(momentum=0.8)(u)
+
+            # Crop u to the skip connection's spatial shape before concatenating;
+            # upsampling can overshoot when a downsampled dimension was odd.
+            ch, cw, cd = get_crop_shape(u, skip_input)
+            u_cropped = Cropping3D(cropping=(ch, cw, cd), data_format="channels_last")(u)
+            u = Concatenate()([u_cropped, skip_input])
+            return u
+
+        # Image input
+        d0 = Input(shape=self.img_shape, name="input_image")
+
+        # Downsampling
+        d1 = conv3d(d0, self.gf, bn=False)
+        d2 = conv3d(d1, self.gf * 2)
+        d3 = conv3d(d2, self.gf * 4)
+        d4 = conv3d(d3, self.gf * 8)
+        d5 = conv3d(d4, self.gf * 8)
+        u3 = deconv3d(d5, d4, self.gf * 8)
+        u4 = deconv3d(u3, d3, self.gf * 4)
+        u5 = deconv3d(u4, d2, self.gf * 2)
+        u6 = deconv3d(u5, d1, self.gf)
+
+        u7 = UpSampling3D(size=2)(u6)
+        output_img = Conv3D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)
+
+        return Model(inputs=[d0], outputs=[output_img])
+
+    def build_discriminator(self):
+
+        def d_layer(layer_input, filters, f_size=4, bn=True):
+            """Discriminator layer"""
+            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
+            d = LeakyReLU(alpha=0.2)(d)
+            if bn:
+                d = BatchNormalization(momentum=0.8)(d)
+            return d
+
+        img_A = Input(shape=self.img_shape)
+        img_B = Input(shape=self.img_shape)
+
+        # Concatenate image and conditioning image by channels to produce input
+        model_input = Concatenate(axis=-1)([img_A, img_B])
+
+        d1 = d_layer(model_input, self.df, bn=False)
+        d2 = d_layer(d1, self.df * 2)
+        d3 = d_layer(d2, self.df * 4)
+        d4 = d_layer(d3, self.df * 8)
+
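+        # The discriminator outputs a grid of per-patch real/fake scores (a 3D
+        # PatchGAN) rather than a single scalar; its spatial shape matches
+        # self.disc_patch computed in __init__.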
+        validity = Conv3D(1, kernel_size=4, strides=1, padding='same')(d4)
+
+        return Model([img_A, img_B], validity)
+
+    def train(self, epochs, batch_size=1, sample_interval=50):
+        start_time = datetime.datetime.now()
+        # Adversarial loss ground truths. Note that the labels are flipped
+        # relative to the usual convention: D is trained toward 0 for real
+        # pairs and 1 for fakes, and G is trained toward 0 ("valid"), so the
+        # scheme is self-consistent.
+        valid = np.zeros((batch_size,) + self.disc_patch)
+        fake = np.ones((batch_size,) + self.disc_patch)
+
+        for epoch in range(epochs):
+            # Save the models at the start of every epoch after the first
+            if epoch > 0:
+                print("Saving Models...")
+                self.generator.save(os.path.join(self.modelpath, "G_model.h5"))  # HDF5 file
+                self.discriminator.save(os.path.join(self.modelpath, "D_model.h5"))  # HDF5 file
+
+            for batch_i, (imgs_A, imgs_B) in enumerate(self.dataloader.load_batch(batch_size)):
+                # ---------------------
+                #  Train Discriminator
+                # ---------------------
+                # Condition on B and generate a translated version
+                fake_A = self.generator.predict([imgs_B])
+
+                # Train the discriminator (original images = real / generated = fake)
+                d_loss_real = self.discriminator.train_on_batch([imgs_A, imgs_B], valid)
+                d_loss_fake = self.discriminator.train_on_batch([fake_A, imgs_B], fake)
+                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
+
+                # -----------------
+                #  Train Generator
+                # -----------------
+
+                # Train the generator (via the combined model; D's weights are frozen)
+                g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])
+                elapsed_time = datetime.datetime.now() - start_time
+                # Print the progress
+                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f] time: %s" % (epoch, epochs,
+                                                                                                      batch_i,
+                                                                                                      self.dataloader.n_batches,
+                                                                                                      d_loss[0],
+                                                                                                      100 * d_loss[1],
+                                                                                                      g_loss[0],
+                                                                                                      elapsed_time))
+
+                # If at save interval => save generated image samples
+                if batch_i % sample_interval == 0:
+                    self.show_progress(epoch, batch_i)
+
+    def show_progress(self, epoch, batch_i):
+        filename = "%d_%d.png" % (epoch, batch_i)
+        if self.isInjector:
+            savepath = os.path.join(config['progress'], "injector")
+        else:
+            savepath = os.path.join(config['progress'], "remover")
+        os.makedirs(savepath, exist_ok=True)
+        r, c = 3, 3
+
+        imgs_A, imgs_B = self.dataloader.load_data(batch_size=3, is_testing=True)
+        fake_A = self.generator.predict([imgs_B])
+
+        gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])
+
+        # Rescale images from [-1, 1] (tanh output range) to [0, 1]
+        gen_imgs = 0.5 * gen_imgs + 0.5
+
+        titles = ['Condition', 'Generated', 'Original']
+        fig, axs = plt.subplots(r, c)
+        cnt = 0
+        for i in range(r):
+            for j in range(c):
+                # Show the middle axial slice of each cube
+                cube = gen_imgs[cnt].reshape((self.img_depth, self.img_rows, self.img_cols))
+                axs[i, j].imshow(cube[int(self.img_depth / 2), :, :])
+                axs[i, j].set_title(titles[i])
+                axs[i, j].axis('off')
+                cnt += 1
+        fig.savefig(os.path.join(savepath, filename))
+        plt.close(fig)
+
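+
+# Minimal usage sketch. The epoch/batch settings below are illustrative
+# placeholders rather than the values used in the original experiments, and
+# config.py must define the paths referenced in Trainer.__init__.
+if __name__ == '__main__':
+    trainer = Trainer(isInjector=True)  # True: injection GAN; False: removal GAN
+    trainer.train(epochs=200, batch_size=32, sample_interval=50)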