--- a
+++ b/Segmentation/train.py
@@ -0,0 +1,114 @@
+"""Cross-validated training and evaluation of a U-Net segmentation model.
+
+For every split produced by ``kf`` a fresh ``segmentation_models`` U-Net is
+trained on the training fold and then scored image by image on the
+validation fold. Dice and Jaccard are recorded at the model's prediction
+resolution ("resized") and after resizing the prediction back to the
+original image size ("original").
+"""
+from sklearn.metrics import jaccard_score
+import time
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+from config import *
+from metrics import *
+from data_loader import *
+# Must be set before segmentation_models is imported.
+os.environ["SM_FRAMEWORK"] = "tf.keras"
+import segmentation_models as sm
+sm.framework()
+# NOTE(review): cv2 is used below without an explicit import here; it
+# presumably arrives via one of the star imports above -- confirm.
+
+
+# Per-image scores: one row per validation image, one column per fold.
+# NOTE(review): assumes every validation fold provides at least ``valsize``
+# images; folds of unequal size would break this -- verify against config.
+validation_dice_original = np.zeros([valsize, splits])
+validation_dice_resized = np.zeros([valsize, splits])
+validation_jaccard_original = np.zeros([valsize, splits])
+validation_jaccard_resized = np.zeros([valsize, splits])
+
+# Keras expects integer step counts; round up so a remainder smaller than
+# one batch is still covered.
+train_steps = int(np.ceil(trainsize / batchsize))
+val_steps = int(np.ceil(valsize / batchsize))
+
+all_history = []
+for cv_count, (train_index, val_index) in enumerate(kf.split(data_num)):
+
+    # Build a fresh model for every fold so no weights carry over.
+    # Only channel 0 of the 3-class output is evaluated below.
+    model = sm.Unet(my_model, encoder_weights="imagenet", input_shape=(256, 256, 3), classes=3, activation='sigmoid')
+    model.compile(optimizer='Adam', loss=jacard_coef_loss, metrics=[jacard_coef, dice_coef])
+    history = model.fit(
+        x=batch_generator(batchsize, generate_data(file_list[train_index], image_path, mask_path, gen_type="train")),
+        epochs=num_epoch,
+        steps_per_epoch=train_steps,
+        validation_steps=val_steps,
+        validation_data=batch_generator(batchsize, generate_data(file_list[val_index], image_path, mask_path, gen_type="val")),
+        validation_freq=1,
+        verbose=1,
+        callbacks=[reduce_lr],
+    )
+
+    # Time the per-image evaluation of this fold. The clock is started once,
+    # outside the loop, so the reported runtime covers every image.
+    time_start = time.time()
+    val_gen = generate_data_pred(file_list[val_index], image_path, mask_path, gen_type="val")
+    for i in range(valsize):
+        original_img, original_mask, X, y_true = next(val_gen)
+        original_shape = original_img.shape
+
+        # Binarize channel 0 of the prediction: values above 127/255 -> 1.
+        y_pred = model.predict(np.expand_dims(X, 0))
+        _, y_pred_thr = cv2.threshold(y_pred[0, :, :, 0] * 255, 127, 255, cv2.THRESH_BINARY)
+        y_pred = (y_pred_thr / 255).astype(int)
+        dice_resized = dice_score(y_true[:, :, 0], y_pred)
+        jaccard_resized = jaccard_score(y_true[:, :, 0], y_pred, average="micro")
+
+        # Linear interpolation yields fractional values along mask edges.
+        # Re-threshold at 0.5: a bare astype(int) truncates, which drops
+        # every pixel that is not exactly 1.0 and shrinks the mask.
+        y_pred_original = cv2.resize(y_pred.astype(float), (original_shape[1], original_shape[0]), interpolation=cv2.INTER_LINEAR)
+        y_pred_original = (y_pred_original >= 0.5).astype(int)
+        # The original mask appears to be on a 0-255 scale; // 255 maps it to 0/1.
+        mask_original = original_mask[:, :, 0] // 255
+        dice_original = dice_score(mask_original, y_pred_original)
+        jaccard_original = jaccard_score(mask_original, y_pred_original, average="micro")
+
+        validation_dice_original[i, cv_count] = dice_original
+        validation_dice_resized[i, cv_count] = dice_resized
+        validation_jaccard_original[i, cv_count] = jaccard_original
+        validation_jaccard_resized[i, cv_count] = jaccard_resized
+
+        # Overlay ground truth (left) and prediction (right) for the first 5 images.
+        if i < 5:
+            plt.figure(figsize=(20, 10))
+            plt.subplot(1, 2, 1)
+            plt.imshow(original_img[..., ::-1], 'gray', interpolation='none')
+            plt.imshow(original_mask / 255.0, 'jet', interpolation='none', alpha=0.4)
+            plt.subplot(1, 2, 2)
+            plt.imshow(original_img[..., ::-1], 'gray', interpolation='none')
+            plt.imshow(y_pred_original, 'jet', interpolation='none', alpha=0.4)
+            plt.show()
+            # Free the figure so figures do not pile up across folds.
+            plt.close()
+
+    dice_resized_mean = validation_dice_resized[:, cv_count].mean()
+    dice_original_mean = validation_dice_original[:, cv_count].mean()
+    jaccard_resized_mean = validation_jaccard_resized[:, cv_count].mean()
+    jaccard_original_mean = validation_jaccard_original[:, cv_count].mean()
+
+    print("--------------------------------------")
+    print("Mean validation DICE (on resized data):", dice_resized_mean)
+    print("Mean validation DICE (on original data):", dice_original_mean)
+    print("--------------------------------------")
+    print("Mean validation Jaccard (on resized data):", jaccard_resized_mean)
+    print("Mean validation Jaccard (on original data):", jaccard_original_mean)
+    print("--------------------------------------")
+    runtime = time.time() - time_start
+    print('Runtime: {} sec'.format(runtime))
+    # Per-epoch validation Dice recorded by Keras for this fold.
+    all_history.append(history.history["val_dice_coef"])
\ No newline at end of file