--- a
+++ b/code/train-generator.py
@@ -0,0 +1,149 @@
+"""
+Purpose: train a U-Net that segments nodules in individual 2D patient CT scan slices.
+Note:
+- this trains from scratch, with no preloaded weights
+- weights are NOT currently saved: no ModelCheckpoint callback or save call is made (a saved unet.hdf5 was intended)
+"""
+
+from __future__ import print_function
+import os
+from glob import glob
+import numpy as np
+import matplotlib.pyplot as plt
+from keras.models import Model
+from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D
+from keras.optimizers import Adam
+from keras.optimizers import SGD
+from keras.callbacks import ModelCheckpoint, LearningRateScheduler
+from keras import backend as K
+
+
+TRAIN_PATH = '/home/ubuntu/data/train_pre/'
+VAL_PATH = '/home/ubuntu/data/val_pre/'
+TEST_PATH = '/home/ubuntu/data/test_pre/'
+IMG_ROWS = 512
+IMG_COLS = 512
+
+SMOOTH = 1.
+
+K.set_image_dim_ordering('th')  # Theano dimension ordering in this code
+
+def dice_coef(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    return (2. * intersection + SMOOTH) / (K.sum(y_true_f) + K.sum(y_pred_f) + SMOOTH)
+
+def dice_coef_loss(y_true, y_pred):
+    return -dice_coef(y_true, y_pred)
+
+def dice_coef_np(y_true,y_pred):
+    y_true_f = y_true.flatten()
+    y_pred_f = y_pred.flatten()
+    intersection = np.sum(y_true_f * y_pred_f)
+    return (2. * intersection + SMOOTH) / (np.sum(y_true_f) + np.sum(y_pred_f) + SMOOTH)
+
+def get_unet():
+    inputs = Input((1,IMG_ROWS, IMG_COLS))
+    conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs)
+    conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv1)
+    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
+
+    conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(pool1)
+    conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv2)
+    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
+
+    conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(pool2)
+    conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv3)
+    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
+
+    conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(pool3)
+    conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv4)
+    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
+
+    conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(pool4)
+    conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(conv5)
+
+    up6 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=1)
+    conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(up6)
+    conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv6)
+
+    up7 = merge([UpSampling2D(size=(2, 2))(conv6), conv3], mode='concat', concat_axis=1)
+    conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(up7)
+    conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv7)
+
+    up8 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=1)
+    conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(up8)
+    conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv8)
+
+    up9 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=1)
+    conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(up9)
+    conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv9)
+
+    conv10 = Convolution2D(1, 1, 1, activation='sigmoid')(conv9)
+
+    model = Model(input=inputs, output=conv10)
+
+    model.compile(optimizer=Adam(lr=1.0e-5), loss=dice_coef_loss, metrics=[dice_coef])
+
+    return model
+
+def generator(path,batch_size):
+
+    lung_mask_list = glob(path + 'final_lung_mask_*.npy')
+    nodule_mask_list = glob(path + 'final_nodule_mask_*.npy')
+    lung_mask_list.sort()
+    nodule_mask_list.sort()
+
+    flag = 0
+    start = 0
+    cnt = 0
+
+    while (1):
+        
+        for i in range(len(lung_mask_list)):
+            lung_file = lung_mask_list[i]
+            nodule_file = nodule_mask_list[i]
+            lung = np.load(lung_file)
+            nodule = np.load(nodule_file)
+
+            if(flag):
+                lung_train = np.concatenate((lung_train,lung[0:batch_supply]),axis=0).reshape([batch_size,1,512,512])
+                nodule_train = np.concatenate((nodule_train,nodule[0:batch_supply]),axis=0).reshape([batch_size,1,512,512])
+                yield (lung_train / 255.0, nodule_train / 255.0)
+                start = batch_supply
+                flag = 0
+            while(start + batch_size < len(lung)):
+                lung_train = lung[start:start+batch_size].reshape([batch_size,1,512,512])
+                nodule_train = nodule[start:start+batch_size].reshape([batch_size,1,512,512])
+                yield (lung_train / 255.0, nodule_train / 255.0)
+                start += batch_size
+            if(start + batch_size == len(lung)):
+                lung_train = lung[start:start+batch_size].reshape([batch_size,1,512,512])
+                nodule_train = nodule[start:start+batch_size].reshape([batch_size,1,512,512])
+                yield (lung_train / 255.0, nodule_train / 255.0)
+                flag = 0
+                start = 0
+            else:
+                lung_train = lung[start:]
+                nodule_train = nodule[start:]
+                batch_supply = batch_size - (len(lung)-start)
+                start = 0
+                flag = 1
+
+
+def train_generator(batch_size):
+    model = get_unet()
+    print('model compileover ...')
+    model.fit_generator(generator(TRAIN_PATH,batch_size),steps_per_epoch = 4540, epochs = 2, verbose = 1, validation_data=generator(VAL_PATH,batch_size),validation_steps=650)
+    
+    data_pred = np.load(TEST_PATH+'final_lung_mask_9.npy').reshape([1007,1,512,512])
+    nodule_true = np.load(TEST_PATH+'final_nodule_mask_9.npy').reshape([1007,1,512,512])
+
+    print(model.evaluate(data_pred,nodule_true,batch_size=2,verbose=1))
+
+
+    
+if __name__ == '__main__':
+    batch_size = 2
+    train_generator(batch_size)
\ No newline at end of file