{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "857a5f70",
"metadata": {},
"outputs": [],
"source": [
"# Importing the necessary libraries\n",
"\n",
"import tensorflow as tf\n",
"import albumentations as albu\n",
"import numpy as np\n",
"import gc\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"from keras.callbacks import CSVLogger\n",
"from datetime import datetime\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import jaccard_score, precision_score, recall_score, accuracy_score, f1_score\n",
"from ModelArchitecture.DiceLoss import dice_metric_loss\n",
"from ModelArchitecture import DUCK_Net\n",
"from ImageLoader import ImageLoader2D"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7852dbbf-f3a4-4db6-8c5e-98889bf3abd4",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"# Checking the number of GPUs available\n",
"\n",
"print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))"
]
},
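{
"cell_type": "code",
"execution_count": null,
"id": "4b9c2f17",
"metadata": {},
"outputs": [],
"source": [
"# Optionally letting TensorFlow allocate GPU memory on demand instead of\n",
"# reserving it all up front; a minimal sketch that is safe to skip\n",
"\n",
"for gpu in tf.config.list_physical_devices('GPU'):\n",
"    try:\n",
"        tf.config.experimental.set_memory_growth(gpu, True)\n",
"    except RuntimeError as e:\n",
"        # Memory growth must be set before the GPUs are initialized\n",
"        print(e)"
]
},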
{
"cell_type": "code",
"execution_count": 3,
"id": "9edbe667",
"metadata": {},
"outputs": [],
"source": [
"# Setting the model parameters\n",
"\n",
"img_size = 352\n",
"dataset_type = 'kvasir' # Options: kvasir/cvc-clinicdb/cvc-colondb/etis-laribpolypdb\n",
"learning_rate = 1e-4\n",
"seed_value = 58800\n",
"filters = 17 # Number of filters, the paper presents the results with 17 and 34\n",
"optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)\n",
"\n",
"ct = datetime.now()\n",
"\n",
"model_type = \"DuckNet\"\n",
"\n",
"progress_path = 'ProgressFull/' + dataset_type + '_progress_csv_' + model_type + '_filters_' + str(filters) + '_' + str(ct) + '.csv'\n",
"progressfull_path = 'ProgressFull/' + dataset_type + '_progress_' + model_type + '_filters_' + str(filters) + '_' + str(ct) + '.txt'\n",
"plot_path = 'ProgressFull/' + dataset_type + '_progress_plot_' + model_type + '_filters_' + str(filters) + '_' + str(ct) + '.png'\n",
"model_path = 'ModelSaveTensorFlow/' + dataset_type + '/' + model_type + '_filters_' + str(filters) + '_' + str(ct)\n",
"\n",
"EPOCHS = 600\n",
"min_loss_for_saving = 0.2"
]
},
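{
"cell_type": "code",
"execution_count": null,
"id": "c81d0a2e",
"metadata": {},
"outputs": [],
"source": [
"# Convenience cell: making sure the output directories referenced above exist,\n",
"# assuming the relative paths defined in the previous cell\n",
"\n",
"import os\n",
"\n",
"os.makedirs('ProgressFull', exist_ok=True)\n",
"os.makedirs('ModelSaveTensorFlow/' + dataset_type, exist_ok=True)"
]
},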
{
"cell_type": "code",
"execution_count": null,
"id": "264050d8",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"# Loading the data\n",
"\n",
"X, Y = ImageLoader2D.load_data(img_size, img_size, -1, 'kvasir')"
]
},
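{
"cell_type": "code",
"execution_count": null,
"id": "7e3f5d90",
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check on the loaded arrays; the expected shapes and value\n",
"# ranges are assumptions about ImageLoader2D's output worth verifying\n",
"\n",
"print('X:', X.shape, X.dtype, 'min:', X.min(), 'max:', X.max())\n",
"print('Y:', Y.shape, Y.dtype, 'min:', Y.min(), 'max:', Y.max())"
]
},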
{
"cell_type": "code",
"execution_count": 6,
"id": "75b6c029",
"metadata": {},
"outputs": [],
"source": [
"# Splitting the data, seed for reproducibility\n",
"\n",
"x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.1, shuffle= True, random_state = seed_value)\n",
"x_train, x_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.111, shuffle= True, random_state = seed_value)"
]
},
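{
"cell_type": "code",
"execution_count": null,
"id": "5a6b7c8d",
"metadata": {},
"outputs": [],
"source": [
"# Verifying the split sizes match the intended 80/10/10 proportions\n",
"\n",
"print('train:', len(x_train), 'valid:', len(x_valid), 'test:', len(x_test))"
]
},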
{
"cell_type": "code",
"execution_count": 7,
"id": "92a7b7d0",
"metadata": {},
"outputs": [],
"source": [
"# Defining the augmentations\n",
"\n",
"aug_train = albu.Compose([\n",
" albu.HorizontalFlip(),\n",
" albu.VerticalFlip(),\n",
" albu.ColorJitter(brightness=(0.6,1.6), contrast=0.2, saturation=0.1, hue=0.01, always_apply=True),\n",
" albu.Affine(scale=(0.5,1.5), translate_percent=(-0.125,0.125), rotate=(-180,180), shear=(-22.5,22), always_apply=True),\n",
"])\n",
"\n",
"def augment_images():\n",
" x_train_out = []\n",
" y_train_out = []\n",
"\n",
" for i in range (len(x_train)):\n",
" ug = aug_train(image=x_train[i], mask=y_train[i])\n",
" x_train_out.append(ug['image']) \n",
" y_train_out.append(ug['mask'])\n",
"\n",
" return np.array(x_train_out), np.array(y_train_out)"
]
},
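{
"cell_type": "code",
"execution_count": null,
"id": "9f0e1d2c",
"metadata": {},
"outputs": [],
"source": [
"# Visual spot check of the augmentation pipeline on one training pair;\n",
"# assumes the images are already scaled to a range matplotlib can display\n",
"\n",
"sample = aug_train(image=x_train[0], mask=y_train[0])\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n",
"axes[0].imshow(sample['image'])\n",
"axes[0].set_title('Augmented image')\n",
"axes[1].imshow(np.squeeze(sample['mask']), cmap='gray')\n",
"axes[1].set_title('Augmented mask')\n",
"for ax in axes:\n",
"    ax.axis('off')\n",
"plt.show()"
]
},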
{
"cell_type": "code",
"execution_count": null,
"id": "1609dd32",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"# Creating the model\n",
"\n",
"model = DUCK_Net.create_model(img_height=img_size, img_width=img_size, input_chanels=3, out_classes=1, starting_filters=filters)"
]
},
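{
"cell_type": "code",
"execution_count": null,
"id": "d4c5b6a7",
"metadata": {},
"outputs": [],
"source": [
"# Inspecting the architecture and parameter count as a quick check that the\n",
"# model was built with the intended filter width\n",
"\n",
"model.summary()\n",
"print('Total parameters:', model.count_params())"
]
},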
{
"cell_type": "code",
"execution_count": 19,
"id": "2e513d42",
"metadata": {},
"outputs": [],
"source": [
"# Compiling the model\n",
"\n",
"model.compile(optimizer=optimizer, loss=dice_metric_loss)"
]
},
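{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check of the loss, assuming dice_metric_loss accepts two\n",
"# same-shaped float arrays: identical masks should score near 0,\n",
"# complementary masks near 1\n",
"\n",
"dummy = np.random.randint(0, 2, size=(1, img_size, img_size, 1)).astype(np.float32)\n",
"print('identical masks:', dice_metric_loss(dummy, dummy).numpy())\n",
"print('complementary masks:', dice_metric_loss(dummy, 1.0 - dummy).numpy())"
]
},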
{
"cell_type": "code",
"execution_count": null,
"id": "6ef712b9",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"# Training the model\n",
"\n",
"step = 0\n",
"\n",
"for epoch in range(0, EPOCHS):\n",
" \n",
" print(f'Training, epoch {epoch}')\n",
" print('Learning Rate: ' + str(learning_rate))\n",
"\n",
" step += 1\n",
" \n",
" image_augmented, mask_augmented = augment_images()\n",
" \n",
" csv_logger = CSVLogger(progress_path, append=True, separator=';')\n",
" \n",
" model.fit(x=image_augmented, y=mask_augmented, epochs=1, batch_size=4, validation_data=(x_valid, y_valid), verbose=1, callbacks=[csv_logger])\n",
" \n",
" prediction_valid = model.predict(x_valid, verbose=0)\n",
" loss_valid = dice_metric_loss(y_valid, prediction_valid)\n",
" \n",
" loss_valid = loss_valid.numpy()\n",
" print(\"Loss Validation: \" + str(loss_valid))\n",
" \n",
" prediction_test = model.predict(x_test, verbose=0)\n",
" loss_test = dice_metric_loss(y_test, prediction_test)\n",
" loss_test = loss_test.numpy()\n",
" print(\"Loss Test: \" + str(loss_test))\n",
" \n",
" with open(progressfull_path, 'a') as f:\n",
" f.write('epoch: ' + str(epoch) + '\\nval_loss: ' + str(loss_valid) + '\\ntest_loss: ' + str(loss_test) + '\\n\\n\\n')\n",
" \n",
" if min_loss_for_saving > loss_valid:\n",
" min_loss_for_saving = loss_valid\n",
" print(\"Saved model with val_loss: \", loss_valid)\n",
" model.save(model_path)\n",
" \n",
" del image_augmented\n",
" del mask_augmented\n",
"\n",
" gc.collect()"
]
},
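{
"cell_type": "code",
"execution_count": null,
"id": "e8f9a0b1",
"metadata": {},
"outputs": [],
"source": [
"# Plotting the logged training curve and saving it to plot_path; a sketch that\n",
"# assumes pandas is available and that the CSVLogger file has 'loss' and\n",
"# 'val_loss' columns (one row per outer epoch, since each fit call runs one epoch)\n",
"\n",
"import pandas as pd\n",
"\n",
"progress = pd.read_csv(progress_path, sep=';')\n",
"\n",
"plt.figure(figsize=(8, 4))\n",
"plt.plot(progress['loss'], label='train loss')\n",
"plt.plot(progress['val_loss'], label='validation loss')\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('dice loss')\n",
"plt.legend()\n",
"plt.savefig(plot_path)\n",
"plt.show()"
]
},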
{
"cell_type": "code",
"execution_count": null,
"id": "2cd52447",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"# Computing the metrics and saving the results\n",
"\n",
"print(\"Loading the model\")\n",
"\n",
"model = tf.keras.models.load_model(model_path, custom_objects={'dice_metric_loss':dice_metric_loss})\n",
"\n",
"prediction_train = model.predict(x_train, batch_size=4)\n",
"prediction_valid = model.predict(x_valid, batch_size=4)\n",
"prediction_test = model.predict(x_test, batch_size=4)\n",
"\n",
"print(\"Predictions done\")\n",
"\n",
"dice_train = f1_score(np.ndarray.flatten(np.array(y_train, dtype=bool)),\n",
" np.ndarray.flatten(prediction_train > 0.5))\n",
"dice_test = f1_score(np.ndarray.flatten(np.array(y_test, dtype=bool)),\n",
" np.ndarray.flatten(prediction_test > 0.5))\n",
"dice_valid = f1_score(np.ndarray.flatten(np.array(y_valid, dtype=bool)),\n",
" np.ndarray.flatten(prediction_valid > 0.5))\n",
"\n",
"print(\"Dice finished\")\n",
"\n",
"\n",
"miou_train = jaccard_score(np.ndarray.flatten(np.array(y_train, dtype=bool)),\n",
" np.ndarray.flatten(prediction_train > 0.5))\n",
"miou_test = jaccard_score(np.ndarray.flatten(np.array(y_test, dtype=bool)),\n",
" np.ndarray.flatten(prediction_test > 0.5))\n",
"miou_valid = jaccard_score(np.ndarray.flatten(np.array(y_valid, dtype=bool)),\n",
" np.ndarray.flatten(prediction_valid > 0.5))\n",
"\n",
"print(\"Miou finished\")\n",
"\n",
"\n",
"precision_train = precision_score(np.ndarray.flatten(np.array(y_train, dtype=bool)),\n",
" np.ndarray.flatten(prediction_train > 0.5))\n",
"precision_test = precision_score(np.ndarray.flatten(np.array(y_test, dtype=bool)),\n",
" np.ndarray.flatten(prediction_test > 0.5))\n",
"precision_valid = precision_score(np.ndarray.flatten(np.array(y_valid, dtype=bool)),\n",
" np.ndarray.flatten(prediction_valid > 0.5))\n",
"\n",
"print(\"Precision finished\")\n",
"\n",
"\n",
"recall_train = recall_score(np.ndarray.flatten(np.array(y_train, dtype=bool)),\n",
" np.ndarray.flatten(prediction_train > 0.5))\n",
"recall_test = recall_score(np.ndarray.flatten(np.array(y_test, dtype=bool)),\n",
" np.ndarray.flatten(prediction_test > 0.5))\n",
"recall_valid = recall_score(np.ndarray.flatten(np.array(y_valid, dtype=bool)),\n",
" np.ndarray.flatten(prediction_valid > 0.5))\n",
"\n",
"print(\"Recall finished\")\n",
"\n",
"\n",
"accuracy_train = accuracy_score(np.ndarray.flatten(np.array(y_train, dtype=bool)),\n",
" np.ndarray.flatten(prediction_train > 0.5))\n",
"accuracy_test = accuracy_score(np.ndarray.flatten(np.array(y_test, dtype=bool)),\n",
" np.ndarray.flatten(prediction_test > 0.5))\n",
"accuracy_valid = accuracy_score(np.ndarray.flatten(np.array(y_valid, dtype=bool)),\n",
" np.ndarray.flatten(prediction_valid > 0.5))\n",
"\n",
"\n",
"print(\"Accuracy finished\")\n",
"\n",
"\n",
"final_file = 'results_' + model_type + '_' + str(filters) + '_' + dataset_type + '.txt'\n",
"print(final_file)\n",
"\n",
"with open(final_file, 'a') as f:\n",
" f.write(dataset_type + '\\n\\n')\n",
" f.write('dice_train: ' + str(dice_train) + ' dice_valid: ' + str(dice_valid) + ' dice_test: ' + str(dice_test) + '\\n\\n')\n",
" f.write('miou_train: ' + str(miou_train) + ' miou_valid: ' + str(miou_valid) + ' miou_test: ' + str(miou_test) + '\\n\\n')\n",
" f.write('precision_train: ' + str(precision_train) + ' precision_valid: ' + str(precision_valid) + ' precision_test: ' + str(precision_test) + '\\n\\n')\n",
" f.write('recall_train: ' + str(recall_train) + ' recall_valid: ' + str(recall_valid) + ' recall_test: ' + str(recall_test) + '\\n\\n')\n",
" f.write('accuracy_train: ' + str(accuracy_train) + ' accuracy_valid: ' + str(accuracy_valid) + ' accuracy_test: ' + str(accuracy_test) + '\\n\\n\\n\\n')\n",
"\n",
"print('File done')"
]
}
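,
{
"cell_type": "code",
"execution_count": null,
"id": "f1e2d3c4",
"metadata": {},
"outputs": [],
"source": [
"# Qualitative check: a few test images next to their ground-truth and predicted\n",
"# masks; relies on prediction_test computed in the previous cell\n",
"\n",
"n_examples = 3\n",
"fig, axes = plt.subplots(n_examples, 3, figsize=(9, 3 * n_examples))\n",
"\n",
"for i in range(n_examples):\n",
"    axes[i, 0].imshow(x_test[i])\n",
"    axes[i, 0].set_title('Image')\n",
"    axes[i, 1].imshow(np.squeeze(y_test[i]), cmap='gray')\n",
"    axes[i, 1].set_title('Ground truth')\n",
"    axes[i, 2].imshow(np.squeeze(prediction_test[i] > 0.5), cmap='gray')\n",
"    axes[i, 2].set_title('Prediction')\n",
"    for ax in axes[i]:\n",
"        ax.axis('off')\n",
"\n",
"plt.show()"
]
}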
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}