Switch to side-by-side view

--- a
+++ b/mediaug/models/pix2pix/pix2pix.py
@@ -0,0 +1,215 @@
+from __future__ import print_function, division
+import scipy
+
+from keras.datasets import mnist
+from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
+from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
+from keras.layers import BatchNormalization, Activation, ZeroPadding2D
+from keras.layers.advanced_activations import LeakyReLU
+from keras.layers.convolutional import UpSampling2D, Conv2D
+from keras.models import Sequential, Model
+from keras.optimizers import Adam
+import datetime
+import matplotlib.pyplot as plt
+import sys
+from data_loader import DataLoader
+import numpy as np
+import os
+
+class Pix2Pix():
+    """Pix2Pix-style conditional GAN built from a U-Net generator and a PatchGAN discriminator.
+
+    The generator maps a conditioning image B to a generated image A. The
+    discriminator receives an (A, B) pair concatenated on the channel axis and
+    outputs a grid of per-patch validity scores. Construction builds and
+    compiles three Keras models: ``discriminator``, ``generator`` and
+    ``combined`` (generator followed by the frozen discriminator).
+    """
+    def __init__(self):
+        """Create the data loader and build/compile all models."""
+        # Input shape
+        self.img_rows = 256
+        self.img_cols = 256
+        self.channels = 3
+        self.img_shape = (self.img_rows, self.img_cols, self.channels)
+
+        # Configure data loader
+        self.dataset_name = 'cells'
+        self.data_loader = DataLoader(dataset_name=self.dataset_name,
+                                      img_res=(self.img_rows, self.img_cols))
+
+
+        # Calculate output shape of D (PatchGAN).
+        # The discriminator stacks four stride-2 convolutions, so each spatial
+        # dimension shrinks by a factor of 2**4. Rows and columns are computed
+        # separately so the label shape stays correct for non-square inputs
+        # (exact when both dimensions are divisible by 16).
+        patch_rows = self.img_rows // 2**4
+        patch_cols = self.img_cols // 2**4
+        self.disc_patch = (patch_rows, patch_cols, 1)
+
+        # Number of filters in the first layer of G and D
+        self.gf = 64
+        self.df = 64
+
+        optimizer = Adam(0.0002, 0.5)
+
+        # Build and compile the discriminator
+        self.discriminator = self.build_discriminator()
+        self.discriminator.compile(loss='mse',
+            optimizer=optimizer,
+            metrics=['accuracy'])
+
+        #-------------------------
+        # Construct Computational
+        #   Graph of Generator
+        #-------------------------
+
+        # Build the generator
+        self.generator = self.build_generator()
+
+        # Input images and their conditioning images
+        img_A = Input(shape=self.img_shape)
+        img_B = Input(shape=self.img_shape)
+
+        # By conditioning on B generate a fake version of A
+        fake_A = self.generator(img_B)
+
+        # For the combined model we will only train the generator.
+        # The flag is flipped after the discriminator was compiled above and
+        # before the combined model is compiled below.
+        self.discriminator.trainable = False
+
+        # Discriminators determines validity of translated images / condition pairs
+        valid = self.discriminator([fake_A, img_B])
+
+        # Two outputs: the patch validity map (mse, weight 1) and the
+        # generated image itself (mae against the target, weight 100).
+        self.combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
+        self.combined.compile(loss=['mse', 'mae'],
+                              loss_weights=[1, 100],
+                              optimizer=optimizer)
+
+    def build_generator(self):
+        """U-Net Generator.
+
+        Seven stride-2 downsampling blocks followed by six upsampling blocks
+        with skip connections and a final upsample + tanh convolution.
+
+        Returns:
+            A Keras ``Model`` mapping an ``img_shape`` input to an output with
+            ``self.channels`` channels.
+        """
+
+        def conv2d(layer_input, filters, f_size=4, bn=True):
+            """Layers used during downsampling"""
+            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
+            d = LeakyReLU(alpha=0.2)(d)
+            if bn:
+                d = BatchNormalization(momentum=0.8)(d)
+            return d
+
+        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
+            """Layers used during upsampling"""
+            u = UpSampling2D(size=2)(layer_input)
+            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
+            if dropout_rate:
+                u = Dropout(dropout_rate)(u)
+            u = BatchNormalization(momentum=0.8)(u)
+            # Skip connection: append the matching encoder feature map.
+            u = Concatenate()([u, skip_input])
+            return u
+
+        # Image input
+        d0 = Input(shape=self.img_shape)
+
+        # Downsampling
+        d1 = conv2d(d0, self.gf, bn=False)
+        d2 = conv2d(d1, self.gf*2)
+        d3 = conv2d(d2, self.gf*4)
+        d4 = conv2d(d3, self.gf*8)
+        d5 = conv2d(d4, self.gf*8)
+        d6 = conv2d(d5, self.gf*8)
+        d7 = conv2d(d6, self.gf*8)
+
+        # Upsampling
+        u1 = deconv2d(d7, d6, self.gf*8)
+        u2 = deconv2d(u1, d5, self.gf*8)
+        u3 = deconv2d(u2, d4, self.gf*8)
+        u4 = deconv2d(u3, d3, self.gf*4)
+        u5 = deconv2d(u4, d2, self.gf*2)
+        u6 = deconv2d(u5, d1, self.gf)
+
+        # Last upsample has no skip connection; tanh bounds the output to [-1, 1].
+        u7 = UpSampling2D(size=2)(u6)
+        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)
+
+        return Model(d0, output_img)
+
+    def build_discriminator(self):
+        """Build the patch discriminator.
+
+        Takes an image and its conditioning image, concatenates them on the
+        channel axis, applies four stride-2 convolution blocks and ends with a
+        single-channel convolution with no activation.
+
+        Returns:
+            A Keras ``Model`` with inputs ``[img_A, img_B]`` and the validity
+            map as output.
+        """
+
+        def d_layer(layer_input, filters, f_size=4, bn=True):
+            """Discriminator layer"""
+            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
+            d = LeakyReLU(alpha=0.2)(d)
+            if bn:
+                d = BatchNormalization(momentum=0.8)(d)
+            return d
+
+        img_A = Input(shape=self.img_shape)
+        img_B = Input(shape=self.img_shape)
+
+        # Concatenate image and conditioning image by channels to produce input
+        combined_imgs = Concatenate(axis=-1)([img_A, img_B])
+
+        d1 = d_layer(combined_imgs, self.df, bn=False)
+        d2 = d_layer(d1, self.df*2)
+        d3 = d_layer(d2, self.df*4)
+        d4 = d_layer(d3, self.df*8)
+
+        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)
+
+        return Model([img_A, img_B], validity)
+
+    def train(self, epochs, batch_size=1, sample_interval=50):
+        """Alternate discriminator and generator updates over the data loader.
+
+        Args:
+            epochs: Number of passes over ``self.data_loader.load_batch``.
+            batch_size: Batch size requested from the loader and used to size
+                the adversarial label arrays.
+            sample_interval: Save a sample image grid every this many batches
+                (whenever ``batch_i % sample_interval == 0``).
+        """
+
+        start_time = datetime.datetime.now()
+
+        # Adversarial loss ground truths.
+        # Built once from batch_size, so every batch yielded by the loader is
+        # expected to hold exactly batch_size samples.
+        valid = np.ones((batch_size,) + self.disc_patch)
+        fake = np.zeros((batch_size,) + self.disc_patch)
+
+        for epoch in range(epochs):
+            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):
+
+                # ---------------------
+                #  Train Discriminator
+                # ---------------------
+
+                # Condition on B and generate a translated version
+                fake_A = self.generator.predict(imgs_B)
+
+                # Train the discriminators (original images = real / generated = Fake)
+                d_loss_real = self.discriminator.train_on_batch([imgs_A, imgs_B], valid)
+                d_loss_fake = self.discriminator.train_on_batch([fake_A, imgs_B], fake)
+                # Element-wise mean of the two [loss, accuracy] results.
+                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
+
+                # -----------------
+                #  Train Generator
+                # -----------------
+
+                # Train the generators: targets are "valid" for the patch map
+                # and the real image A for the reconstruction output.
+                g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])
+
+                elapsed_time = datetime.datetime.now() - start_time
+                # Plot the progress
+                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f] time: %s" % (epoch, epochs,
+                                                                        batch_i, self.data_loader.n_batches,
+                                                                        d_loss[0], 100*d_loss[1],
+                                                                        g_loss[0],
+                                                                        elapsed_time))
+
+                # If at save interval => save generated image samples
+                if batch_i % sample_interval == 0:
+                    self.sample_images(epoch, batch_i)
+
+    def sample_images(self, epoch, batch_i):
+        """Save a 3x3 grid of condition / generated / original test images.
+
+        The figure is written to ``images/<dataset_name>/<epoch>_<batch_i>.png``;
+        the directory is created if it does not exist.
+
+        Args:
+            epoch: Current epoch index, used in the file name.
+            batch_i: Current batch index, used in the file name.
+        """
+        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
+        r, c = 3, 3
+
+        imgs_A, imgs_B = self.data_loader.load_data(batch_size=3, is_testing=True)
+        fake_A = self.generator.predict(imgs_B)
+
+        # Stack so that row 0 = conditions, row 1 = generated, row 2 = originals.
+        gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])
+
+        # Rescale images 0 - 1
+        # (the generator output is tanh-bounded; presumably the loader also
+        # yields images in [-1, 1] -- TODO confirm against DataLoader)
+        gen_imgs = 0.5 * gen_imgs + 0.5
+
+        titles = ['Condition', 'Generated', 'Original']
+        fig, axs = plt.subplots(r, c)
+        cnt = 0
+        for i in range(r):
+            for j in range(c):
+                axs[i,j].imshow(gen_imgs[cnt])
+                axs[i, j].set_title(titles[i])
+                axs[i,j].axis('off')
+                cnt += 1
+        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
+        # Close this specific figure rather than whichever one happens to be
+        # current, so figures cannot accumulate across sampling calls.
+        plt.close(fig)
+
+
+if __name__ == '__main__':
+    # Script entry point: build the models, then run the training loop with
+    # the hyper-parameters collected in one place.
+    training_args = dict(epochs=400, batch_size=1, sample_interval=200)
+    gan = Pix2Pix()
+    gan.train(**training_args)