Diff of /Segmentation/train.py [000000] .. [e698c9]

Switch to unified view

a b/Segmentation/train.py
1
from sklearn.metrics import jaccard_score
2
import time
3
import numpy as np
4
import matplotlib.pyplot as plt
5
import os
6
from config import *
7
from metrics import *
8
from data_loader import *
9
os.environ["SM_FRAMEWORK"] = "tf.keras"
10
import segmentation_models as sm
11
sm.framework()
12
13
14
validation_dice_original = np.zeros([valsize,splits])
15
validation_dice_resized = np.zeros([valsize,splits])
16
validation_jaccard_original = np.zeros([valsize,splits])
17
validation_jaccard_resized = np.zeros([valsize,splits])
18
cv_count = 0
19
20
all_history= []
21
for train_index, val_index in kf.split(data_num):
22
23
    #model = get_model(img_size, num_classes)
24
    model = sm.Unet(my_model, encoder_weights="imagenet", input_shape=( 256,256, 3), classes=3, activation='sigmoid')
25
    model.compile(optimizer='Adam', loss=jacard_coef_loss, metrics = [jacard_coef, dice_coef])
26
    history = model.fit(x=batch_generator(batchsize, generate_data(file_list[train_index], image_path, mask_path, gen_type = "train")), epochs=num_epoch, 
27
                            steps_per_epoch=(trainsize/batchsize), 
28
                            validation_steps=(valsize/batchsize),
29
                            validation_data=batch_generator(batchsize, generate_data(file_list[val_index], image_path, mask_path, gen_type = "val")), 
30
                            validation_freq=1, 
31
                            verbose = 1, 
32
                            callbacks=[reduce_lr],
33
                            )
34
    val_gen  = generate_data_pred(file_list[val_index], image_path, mask_path, gen_type = "val")
35
    for i in range(valsize):
36
        time_start = time.time()
37
        original_img, original_mask, X, y_true = next(val_gen)
38
        original_shape = original_img.shape
39
        y_pred = model.predict(np.expand_dims(X,0))
40
        _,y_pred_thr = cv2.threshold(y_pred[0,:,:,0]*255, 127, 255, cv2.THRESH_BINARY)
41
        y_pred = (y_pred_thr/255).astype(int)
42
        dice_resized = dice_score(y_true[:,:,0],y_pred)
43
        jaccard_resized = jaccard_score(y_true[:,:,0],y_pred, average="micro")
44
        
45
        y_pred_original = cv2.resize(y_pred.astype(float), (original_shape[1],original_shape[0]), interpolation= cv2.INTER_LINEAR)
46
        dice_original = dice_score(original_mask[:,:,0]//255,y_pred_original.astype(int))
47
        jaccard_original = jaccard_score(original_mask[:,:,0]//255,y_pred_original.astype(int), average="micro")
48
        
49
        validation_dice_original[i,cv_count] = dice_original
50
        validation_dice_resized[i,cv_count] = dice_resized
51
        validation_jaccard_original[i,cv_count] = jaccard_original
52
        validation_jaccard_resized[i,cv_count] = jaccard_resized
53
        
54
        if i < 5:
55
            plt.figure(figsize=(20,10))
56
            plt.subplot(1,2,1)
57
            plt.imshow(original_img[...,::-1], 'gray', interpolation='none')
58
            plt.imshow(original_mask/255.0, 'jet', interpolation='none', alpha=0.4)
59
            plt.subplot(1,2,2)
60
            plt.imshow(original_img[...,::-1], 'gray', interpolation='none')
61
            plt.imshow(y_pred_original, 'jet', interpolation='none', alpha=0.4)
62
            plt.show()
63
            
64
65
    dice_resized_mean = validation_dice_resized[:,cv_count].mean()
66
    dice_original_mean = validation_dice_original[:,cv_count].mean()
67
    jaccard_resized_mean = validation_jaccard_resized[:,cv_count].mean()
68
    jaccard_original_mean = validation_jaccard_original[:,cv_count].mean()
69
        
70
    print("--------------------------------------")
71
    print("Mean validation DICE (on resized data):", dice_resized_mean) 
72
    print("Mean validation DICE (on original data):", dice_original_mean)
73
    print("--------------------------------------")
74
    print("Mean validation Jaccard (on resized data):", jaccard_resized_mean) 
75
    print("Mean validation Jaccard (on original data):", jaccard_original_mean)
76
    print("--------------------------------------")
77
    runtime = time.time() - time_start 
78
    print('Runtime: {} sec'.format(runtime))
79
    cv_count +=1
80
    all_history.append(history.history["val_dice_coef"])