|
a |
|
b/Segmentation/train.py |
|
|
1 |
from sklearn.metrics import jaccard_score |
|
|
2 |
import time |
|
|
3 |
import numpy as np |
|
|
4 |
import matplotlib.pyplot as plt |
|
|
5 |
import os |
|
|
6 |
from config import * |
|
|
7 |
from metrics import * |
|
|
8 |
from data_loader import * |
|
|
9 |
os.environ["SM_FRAMEWORK"] = "tf.keras" |
|
|
10 |
import segmentation_models as sm |
|
|
11 |
sm.framework() |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
validation_dice_original = np.zeros([valsize,splits]) |
|
|
15 |
validation_dice_resized = np.zeros([valsize,splits]) |
|
|
16 |
validation_jaccard_original = np.zeros([valsize,splits]) |
|
|
17 |
validation_jaccard_resized = np.zeros([valsize,splits]) |
|
|
18 |
cv_count = 0 |
|
|
19 |
|
|
|
20 |
all_history= [] |
|
|
21 |
for train_index, val_index in kf.split(data_num): |
|
|
22 |
|
|
|
23 |
#model = get_model(img_size, num_classes) |
|
|
24 |
model = sm.Unet(my_model, encoder_weights="imagenet", input_shape=( 256,256, 3), classes=3, activation='sigmoid') |
|
|
25 |
model.compile(optimizer='Adam', loss=jacard_coef_loss, metrics = [jacard_coef, dice_coef]) |
|
|
26 |
history = model.fit(x=batch_generator(batchsize, generate_data(file_list[train_index], image_path, mask_path, gen_type = "train")), epochs=num_epoch, |
|
|
27 |
steps_per_epoch=(trainsize/batchsize), |
|
|
28 |
validation_steps=(valsize/batchsize), |
|
|
29 |
validation_data=batch_generator(batchsize, generate_data(file_list[val_index], image_path, mask_path, gen_type = "val")), |
|
|
30 |
validation_freq=1, |
|
|
31 |
verbose = 1, |
|
|
32 |
callbacks=[reduce_lr], |
|
|
33 |
) |
|
|
34 |
val_gen = generate_data_pred(file_list[val_index], image_path, mask_path, gen_type = "val") |
|
|
35 |
for i in range(valsize): |
|
|
36 |
time_start = time.time() |
|
|
37 |
original_img, original_mask, X, y_true = next(val_gen) |
|
|
38 |
original_shape = original_img.shape |
|
|
39 |
y_pred = model.predict(np.expand_dims(X,0)) |
|
|
40 |
_,y_pred_thr = cv2.threshold(y_pred[0,:,:,0]*255, 127, 255, cv2.THRESH_BINARY) |
|
|
41 |
y_pred = (y_pred_thr/255).astype(int) |
|
|
42 |
dice_resized = dice_score(y_true[:,:,0],y_pred) |
|
|
43 |
jaccard_resized = jaccard_score(y_true[:,:,0],y_pred, average="micro") |
|
|
44 |
|
|
|
45 |
y_pred_original = cv2.resize(y_pred.astype(float), (original_shape[1],original_shape[0]), interpolation= cv2.INTER_LINEAR) |
|
|
46 |
dice_original = dice_score(original_mask[:,:,0]//255,y_pred_original.astype(int)) |
|
|
47 |
jaccard_original = jaccard_score(original_mask[:,:,0]//255,y_pred_original.astype(int), average="micro") |
|
|
48 |
|
|
|
49 |
validation_dice_original[i,cv_count] = dice_original |
|
|
50 |
validation_dice_resized[i,cv_count] = dice_resized |
|
|
51 |
validation_jaccard_original[i,cv_count] = jaccard_original |
|
|
52 |
validation_jaccard_resized[i,cv_count] = jaccard_resized |
|
|
53 |
|
|
|
54 |
if i < 5: |
|
|
55 |
plt.figure(figsize=(20,10)) |
|
|
56 |
plt.subplot(1,2,1) |
|
|
57 |
plt.imshow(original_img[...,::-1], 'gray', interpolation='none') |
|
|
58 |
plt.imshow(original_mask/255.0, 'jet', interpolation='none', alpha=0.4) |
|
|
59 |
plt.subplot(1,2,2) |
|
|
60 |
plt.imshow(original_img[...,::-1], 'gray', interpolation='none') |
|
|
61 |
plt.imshow(y_pred_original, 'jet', interpolation='none', alpha=0.4) |
|
|
62 |
plt.show() |
|
|
63 |
|
|
|
64 |
|
|
|
65 |
dice_resized_mean = validation_dice_resized[:,cv_count].mean() |
|
|
66 |
dice_original_mean = validation_dice_original[:,cv_count].mean() |
|
|
67 |
jaccard_resized_mean = validation_jaccard_resized[:,cv_count].mean() |
|
|
68 |
jaccard_original_mean = validation_jaccard_original[:,cv_count].mean() |
|
|
69 |
|
|
|
70 |
print("--------------------------------------") |
|
|
71 |
print("Mean validation DICE (on resized data):", dice_resized_mean) |
|
|
72 |
print("Mean validation DICE (on original data):", dice_original_mean) |
|
|
73 |
print("--------------------------------------") |
|
|
74 |
print("Mean validation Jaccard (on resized data):", jaccard_resized_mean) |
|
|
75 |
print("Mean validation Jaccard (on original data):", jaccard_original_mean) |
|
|
76 |
print("--------------------------------------") |
|
|
77 |
runtime = time.time() - time_start |
|
|
78 |
print('Runtime: {} sec'.format(runtime)) |
|
|
79 |
cv_count +=1 |
|
|
80 |
all_history.append(history.history["val_dice_coef"]) |