1220 lines (1219 with data), 281.3 kB
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "nDXW7A7pwZ7a"
},
"source": [
"# UNET SEGMENTATION (SUJAL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M5oF6qqZwZ70"
},
"outputs": [],
"source": [
"## Imports\n",
"import os\n",
"import sys\n",
"import random\n",
"\n",
"import numpy as np\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"\n",
"## Seeding\n",
"# Call the seeding functions. Assigning to them (e.g. `random.seed = seed`)\n",
"# only replaces the function with an int and leaves every generator unseeded.\n",
"seed = 42\n",
"random.seed(seed)\n",
"np.random.seed(seed)\n",
"tf.random.set_seed(seed)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "70jPovBLwZ7-"
},
"source": [
"## Data "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SvRbQp8LwZ8B"
},
"outputs": [],
"source": [
"class DataGen(keras.utils.Sequence):\n",
"    \"\"\"Keras Sequence yielding batches of (image, mask) arrays read from disk.\n",
"\n",
"    Images are read from ``<path>/images/<id>`` and masks from\n",
"    ``<path>/masks/<id>``. Both are resized to ``image_size`` x ``image_size``\n",
"    and divided by 255.0.\n",
"\n",
"    Args:\n",
"        ids: File names shared by the ``images`` and ``masks`` folders.\n",
"        path: Dataset root containing the ``images`` and ``masks`` folders.\n",
"        batch_size: Samples per batch (the last batch may be smaller).\n",
"        image_size: Side length in pixels that images and masks are resized to.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, ids, path, batch_size=8, image_size=128):\n",
"        # Let the Keras base class set up its own state (Keras 3 warns otherwise).\n",
"        super().__init__()\n",
"        self.ids = ids\n",
"        self.path = path\n",
"        self.batch_size = batch_size\n",
"        self.image_size = image_size\n",
"        self.on_epoch_end()\n",
"\n",
"    def __load__(self, id_name):\n",
"        \"\"\"Load one sample and return it as ``(image, mask)``.\n",
"\n",
"        The image is read with flag 1 (3-channel colour, OpenCV channel order)\n",
"        and the mask with flag 0 (single-channel grayscale). The mask gets a\n",
"        trailing channel axis so it has shape ``(image_size, image_size, 1)``.\n",
"\n",
"        Raises:\n",
"            FileNotFoundError: If the image or the mask cannot be read.\n",
"        \"\"\"\n",
"        image_path = os.path.join(self.path, \"images\", id_name)\n",
"        mask_path = os.path.join(self.path, \"masks\", id_name)\n",
"        target_size = (self.image_size, self.image_size)\n",
"\n",
"        # cv2.imread returns None instead of raising when a file is missing or\n",
"        # unreadable; fail here with a clear message rather than inside cv2.resize.\n",
"        image = cv2.imread(image_path, 1)\n",
"        if image is None:\n",
"            raise FileNotFoundError(f\"Could not read image: {image_path}\")\n",
"        mask = cv2.imread(mask_path, 0)\n",
"        if mask is None:\n",
"            raise FileNotFoundError(f\"Could not read mask: {mask_path}\")\n",
"\n",
"        image = cv2.resize(image, target_size)\n",
"        mask = np.expand_dims(cv2.resize(mask, target_size), axis=-1)\n",
"\n",
"        # Normalize\n",
"        return image / 255.0, mask / 255.0\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return batch ``index`` as a pair of arrays ``(images, masks)``.\n",
"\n",
"        Slicing clamps at the end of ``self.ids``, so the final batch is simply\n",
"        shorter. ``self.batch_size`` is never modified: overwriting it for the\n",
"        last batch would shift the slice to the wrong samples and corrupt\n",
"        ``__len__`` and every following epoch.\n",
"        \"\"\"\n",
"        start = index * self.batch_size\n",
"        files_batch = self.ids[start : start + self.batch_size]\n",
"\n",
"        images = []\n",
"        masks = []\n",
"        for id_name in files_batch:\n",
"            _img, _mask = self.__load__(id_name)\n",
"            images.append(_img)\n",
"            masks.append(_mask)\n",
"\n",
"        return np.array(images), np.array(masks)\n",
"\n",
"    def on_epoch_end(self):\n",
"        \"\"\"Epoch-end hook; intentionally a no-op.\"\"\"\n",
"        pass\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the number of batches per epoch, counting a partial last batch.\"\"\"\n",
"        return int(np.ceil(len(self.ids) / float(self.batch_size)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lcEkbVTQWGv-",
"outputId": "5df2d48b-d9fc-47c6-d2d4-75d385054910"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mounted at /content/drive\n"
]
}
],
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rUG6cDK1wZ8I"
},
"source": [
"## Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2QwbWknPwZ8M"
},
"outputs": [],
"source": [
"image_size = 128\n",
"train_path = \"/content/drive/MyDrive/UNet Dataset/Cardiac Dataset/train\"\n",
"epochs = 30\n",
"batch_size = 8\n",
"\n",
"## Training Ids\n",
"# os.listdir returns entries in arbitrary order; sort them so the\n",
"# train/validation split below is identical on every run.\n",
"all_ids = sorted(os.listdir(os.path.join(train_path, \"images\")))\n",
"\n",
"## Validation Data Size\n",
"val_data_size = 10\n",
"\n",
"# Hold out the first `val_data_size` ids for validation; train on the rest.\n",
"valid_ids = all_ids[:val_data_size]\n",
"train_ids = all_ids[val_data_size:]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6_cMh29vwZ8R",
"outputId": "534e3fce-2174-4f8d-99d3-c19550f24f70"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(8, 128, 128, 3) (8, 128, 128, 1)\n"
]
}
],
"source": [
"# Smoke test: build a generator over the training ids and inspect the first batch.\n",
"gen = DataGen(train_ids, train_path, batch_size=batch_size, image_size=image_size)\n",
"x, y = gen[0]\n",
"print(x.shape, y.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 309
},
"id": "NCG2jWo9wZ8c",
"outputId": "f158147f-b3b7-4c7f-e0a0-13219d136aa7",
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7caa566820b0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Show a random sample from the batch next to its mask.\n",
"r = random.randint(0, len(x)-1)\n",
"\n",
"fig = plt.figure()\n",
"fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
"\n",
"ax = fig.add_subplot(1, 2, 1)\n",
"# cv2.imread yields BGR channel order; reverse the last axis to RGB for matplotlib.\n",
"ax.imshow(x[r][..., ::-1])\n",
"ax.set_title(\"Image\")\n",
"\n",
"ax = fig.add_subplot(1, 2, 2)\n",
"ax.imshow(np.reshape(y[r], (image_size, image_size)), cmap=\"gray\")\n",
"ax.set_title(\"Mask\")\n",
"\n",
"# Render explicitly so the cell does not echo the AxesImage repr.\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1xJMvRqswZ8o"
},
"source": [
"## Different Convolutional Blocks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xtA9Rr75wZ8p"
},
"outputs": [],
"source": [
"def down_block(x, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n",
"    \"\"\"Encoder stage: two ReLU convolutions followed by 2x2 max pooling.\n",
"\n",
"    Returns:\n",
"        A pair ``(features, pooled)``: the output of the second convolution\n",
"        and that same tensor after max pooling with pool size and stride 2.\n",
"    \"\"\"\n",
"    features = x\n",
"    for _ in range(2):\n",
"        features = keras.layers.Conv2D(\n",
"            filters, kernel_size, padding=padding, strides=strides, activation=\"relu\"\n",
"        )(features)\n",
"    pooled = keras.layers.MaxPool2D((2, 2), (2, 2))(features)\n",
"    return features, pooled\n",
"\n",
"def up_block(x, skip, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n",
"    \"\"\"Decoder stage: upsample ``x`` by 2, concatenate ``skip``, apply two ReLU convolutions.\n",
"\n",
"    Returns:\n",
"        The output tensor of the second convolution.\n",
"    \"\"\"\n",
"    conv_kwargs = dict(padding=padding, strides=strides, activation=\"relu\")\n",
"    upsampled = keras.layers.UpSampling2D((2, 2))(x)\n",
"    merged = keras.layers.Concatenate()([upsampled, skip])\n",
"    out = keras.layers.Conv2D(filters, kernel_size, **conv_kwargs)(merged)\n",
"    out = keras.layers.Conv2D(filters, kernel_size, **conv_kwargs)(out)\n",
"    return out\n",
"\n",
"def bottleneck(x, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n",
"    \"\"\"Apply two stacked ReLU convolutions to ``x`` with no pooling or upsampling.\"\"\"\n",
"    conv_kwargs = dict(padding=padding, strides=strides, activation=\"relu\")\n",
"    hidden = keras.layers.Conv2D(filters, kernel_size, **conv_kwargs)(x)\n",
"    return keras.layers.Conv2D(filters, kernel_size, **conv_kwargs)(hidden)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QMdNQ_HVwZ82"
},
"source": [
"## UNet Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YAa1RbB8wZ89"
},
"outputs": [],
"source": [
"def UNet():\n",
"    \"\"\"Build the U-Net: four encoder stages, a bottleneck, and four decoder stages.\n",
"\n",
"    The input has shape ``(image_size, image_size, 3)``. The output is a\n",
"    single-channel map produced by a 1x1 convolution with sigmoid activation.\n",
"\n",
"    Returns:\n",
"        An uncompiled ``keras.models.Model``.\n",
"    \"\"\"\n",
"    encoder_filters = [16, 32, 64, 128]\n",
"    bottleneck_filters = 256\n",
"\n",
"    inputs = keras.layers.Input((image_size, image_size, 3))\n",
"\n",
"    # Encoder: each stage halves the spatial size (128 -> 64 -> 32 -> 16 -> 8\n",
"    # when image_size is 128) and keeps its pre-pooling output as a skip tensor.\n",
"    skips = []\n",
"    tensor = inputs\n",
"    for n_filters in encoder_filters:\n",
"        skip, tensor = down_block(tensor, n_filters)\n",
"        skips.append(skip)\n",
"\n",
"    tensor = bottleneck(tensor, bottleneck_filters)\n",
"\n",
"    # Decoder: walk the stages in reverse, doubling the spatial size each time\n",
"    # and merging the matching encoder skip tensor.\n",
"    for n_filters, skip in zip(reversed(encoder_filters), reversed(skips)):\n",
"        tensor = up_block(tensor, skip, n_filters)\n",
"\n",
"    outputs = keras.layers.Conv2D(1, (1, 1), padding=\"same\", activation=\"sigmoid\")(tensor)\n",
"    return keras.models.Model(inputs, outputs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "ntBPY06awZ9E",
"outputId": "bc8241cf-1ca4-46a5-ddf1-afcb43ddc395",
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"functional\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">448</span> │ input_layer[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">2,320</span> │ conv2d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ max_pooling2d │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv2d_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">4,640</span> │ max_pooling2d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">9,248</span> │ conv2d_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ max_pooling2d_1 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv2d_3[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_4 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">18,496</span> │ max_pooling2d_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_5 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ max_pooling2d_2 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv2d_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_6 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">73,856</span> │ max_pooling2d_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_7 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">147,584</span> │ conv2d_6[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ max_pooling2d_3 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv2d_7[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_8 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">295,168</span> │ max_pooling2d_3[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_9 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">590,080</span> │ conv2d_8[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ up_sampling2d │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv2d_9[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">UpSampling2D</span>) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ concatenate (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">384</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ up_sampling2d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ │ │ │ conv2d_7[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_10 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">442,496</span> │ concatenate[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_11 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">147,584</span> │ conv2d_10[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ up_sampling2d_1 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv2d_11[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">UpSampling2D</span>) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ concatenate_1 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">192</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ up_sampling2d_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ conv2d_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_12 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">110,656</span> │ concatenate_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_13 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">36,928</span> │ conv2d_12[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ up_sampling2d_2 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv2d_13[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">UpSampling2D</span>) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ concatenate_2 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">96</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ up_sampling2d_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ conv2d_3[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_14 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">27,680</span> │ concatenate_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_15 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">9,248</span> │ conv2d_14[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ up_sampling2d_3 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv2d_15[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">UpSampling2D</span>) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ concatenate_3 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">48</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ up_sampling2d_3[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ conv2d_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_16 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">6,928</span> │ concatenate_3[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_17 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">2,320</span> │ conv2d_16[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_18 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">17</span> │ conv2d_17[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m448\u001b[0m │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_1 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m2,320\u001b[0m │ conv2d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ max_pooling2d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv2d_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_2 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m4,640\u001b[0m │ max_pooling2d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_3 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m9,248\u001b[0m │ conv2d_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ max_pooling2d_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv2d_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_4 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m18,496\u001b[0m │ max_pooling2d_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_5 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m36,928\u001b[0m │ conv2d_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ max_pooling2d_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv2d_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_6 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m73,856\u001b[0m │ max_pooling2d_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_7 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m147,584\u001b[0m │ conv2d_6[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ max_pooling2d_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m, \u001b[38;5;34m8\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv2d_7[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_8 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m, \u001b[38;5;34m8\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m295,168\u001b[0m │ max_pooling2d_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_9 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m, \u001b[38;5;34m8\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m590,080\u001b[0m │ conv2d_8[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ up_sampling2d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv2d_9[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mUpSampling2D\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ concatenate (\u001b[38;5;33mConcatenate\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m384\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ up_sampling2d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ │ │ │ conv2d_7[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_10 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m442,496\u001b[0m │ concatenate[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_11 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m147,584\u001b[0m │ conv2d_10[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ up_sampling2d_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv2d_11[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mUpSampling2D\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ concatenate_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m192\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ up_sampling2d_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ conv2d_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_12 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m110,656\u001b[0m │ concatenate_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_13 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m36,928\u001b[0m │ conv2d_12[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ up_sampling2d_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv2d_13[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mUpSampling2D\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ concatenate_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m96\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ up_sampling2d_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ conv2d_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_14 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m27,680\u001b[0m │ concatenate_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_15 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m64\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m9,248\u001b[0m │ conv2d_14[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ up_sampling2d_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv2d_15[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mUpSampling2D\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ concatenate_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m48\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ up_sampling2d_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ conv2d_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_16 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m6,928\u001b[0m │ concatenate_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_17 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m2,320\u001b[0m │ conv2d_16[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_18 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m128\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m17\u001b[0m │ conv2d_17[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,962,625</span> (7.49 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,962,625\u001b[0m (7.49 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,962,625</span> (7.49 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,962,625\u001b[0m (7.49 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = UNet()\n",
"model.compile(optimizer=\"adam\", loss=\"binary_crossentropy\", metrics=[\"acc\"])\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nmkxq4-BwZ9Q"
},
"source": [
"## Training the model"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6I34e4tcwZ9T",
"outputId": "427bd73a-0c95-4b03-ce95-6584de960233",
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:122: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.\n",
" self._warn_if_super_not_called()\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m631s\u001b[0m 3s/step - acc: 0.8556 - loss: 0.3451 - val_acc: 0.9528 - val_loss: 0.0973\n",
"Epoch 2/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 679us/step - acc: 0.9566 - loss: 0.0906 - val_acc: 0.9547 - val_loss: 0.1025\n",
"Epoch 3/30\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/lib/python3.10/contextlib.py:153: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.\n",
" self.gen.throw(typ, value, traceback)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m215s\u001b[0m 727ms/step - acc: 0.9470 - loss: 0.1336 - val_acc: 0.9685 - val_loss: 0.0766\n",
"Epoch 4/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m156s\u001b[0m 735ms/step - acc: 0.9662 - loss: 0.0772 - val_acc: 0.9793 - val_loss: 0.0495\n",
"Epoch 5/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m110s\u001b[0m 516ms/step - acc: 0.9653 - loss: 0.0823 - val_acc: 0.9794 - val_loss: 0.0416\n",
"Epoch 6/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m155s\u001b[0m 725ms/step - acc: 0.9696 - loss: 0.0687\n",
"Epoch 7/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m158s\u001b[0m 741ms/step - acc: 0.9718 - loss: 0.0647 - val_acc: 0.9770 - val_loss: 0.0795\n",
"Epoch 8/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m103s\u001b[0m 484ms/step - acc: 0.9697 - loss: 0.0692 - val_acc: 0.9852 - val_loss: 0.0317\n",
"Epoch 9/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m188s\u001b[0m 762ms/step - acc: 0.9759 - loss: 0.0537 - val_acc: 0.9818 - val_loss: 0.0413\n",
"Epoch 10/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m156s\u001b[0m 729ms/step - acc: 0.9766 - loss: 0.0518 - val_acc: 0.9870 - val_loss: 0.0268\n",
"Epoch 11/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m106s\u001b[0m 495ms/step - acc: 0.9774 - loss: 0.0493 - val_acc: 0.9907 - val_loss: 0.0196\n",
"Epoch 12/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m174s\u001b[0m 735ms/step - acc: 0.9782 - loss: 0.0475\n",
"Epoch 13/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m164s\u001b[0m 768ms/step - acc: 0.9804 - loss: 0.0416 - val_acc: 0.9888 - val_loss: 0.0219\n",
"Epoch 14/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m106s\u001b[0m 499ms/step - acc: 0.9798 - loss: 0.0431 - val_acc: 0.9835 - val_loss: 0.0333\n",
"Epoch 15/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m156s\u001b[0m 733ms/step - acc: 0.9802 - loss: 0.0433 - val_acc: 0.9824 - val_loss: 0.0368\n",
"Epoch 16/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m158s\u001b[0m 742ms/step - acc: 0.9822 - loss: 0.0369 - val_acc: 0.9895 - val_loss: 0.0231\n",
"Epoch 17/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 489ms/step - acc: 0.9821 - loss: 0.0374 - val_acc: 0.9927 - val_loss: 0.0176\n",
"Epoch 18/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m186s\u001b[0m 762ms/step - acc: 0.9823 - loss: 0.0369\n",
"Epoch 19/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m154s\u001b[0m 722ms/step - acc: 0.9827 - loss: 0.0366 - val_acc: 0.9879 - val_loss: 0.0248\n",
"Epoch 20/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m103s\u001b[0m 482ms/step - acc: 0.9820 - loss: 0.0384 - val_acc: 0.9872 - val_loss: 0.0247\n",
"Epoch 21/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m156s\u001b[0m 728ms/step - acc: 0.9836 - loss: 0.0339 - val_acc: 0.9885 - val_loss: 0.0257\n",
"Epoch 22/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m161s\u001b[0m 756ms/step - acc: 0.9826 - loss: 0.0369 - val_acc: 0.9905 - val_loss: 0.0205\n",
"Epoch 23/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 490ms/step - acc: 0.9843 - loss: 0.0324 - val_acc: 0.9920 - val_loss: 0.0178\n",
"Epoch 24/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m157s\u001b[0m 734ms/step - acc: 0.9848 - loss: 0.0307\n",
"Epoch 25/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m156s\u001b[0m 733ms/step - acc: 0.9840 - loss: 0.0330 - val_acc: 0.9902 - val_loss: 0.0202\n",
"Epoch 26/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m103s\u001b[0m 483ms/step - acc: 0.9852 - loss: 0.0306 - val_acc: 0.9871 - val_loss: 0.0251\n",
"Epoch 27/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m164s\u001b[0m 767ms/step - acc: 0.9842 - loss: 0.0333 - val_acc: 0.9803 - val_loss: 0.0453\n",
"Epoch 28/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m156s\u001b[0m 731ms/step - acc: 0.9821 - loss: 0.0389 - val_acc: 0.9901 - val_loss: 0.0196\n",
"Epoch 29/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 485ms/step - acc: 0.9840 - loss: 0.0330 - val_acc: 0.9919 - val_loss: 0.0153\n",
"Epoch 30/30\n",
"\u001b[1m213/213\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m171s\u001b[0m 713ms/step - acc: 0.9847 - loss: 0.0307\n"
]
},
{
"data": {
"text/plain": [
"<keras.src.callbacks.history.History at 0x7caa5404be20>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_gen = DataGen(train_ids, train_path, image_size=image_size, batch_size=batch_size)\n",
"valid_gen = DataGen(valid_ids, train_path, image_size=image_size, batch_size=batch_size)\n",
"\n",
"train_steps = len(train_ids)//batch_size\n",
"valid_steps = len(valid_ids)//batch_size\n",
"\n",
"history = model.fit(train_gen, validation_data=valid_gen, steps_per_epoch=train_steps,\n",
" validation_steps=valid_steps, epochs=epochs)\n",
"\n",
"# Save the weights\n",
"model.save_weights(\"UNetW.weights.h5\")\n",
"\n",
"# Plot training & validation accuracy and loss\n",
"acc = history.history['acc']\n",
"val_acc = history.history['val_acc']\n",
"loss = history.history['loss']\n",
"val_loss = history.history['val_loss']\n",
"num_epochs = min(len(acc), len(val_acc))\n",
"epochs_range = range(1, num_epochs + 1)\n",
"\n",
"plt.figure(figsize=(16, 6))\n",
"\n",
"# Accuracy plot\n",
"plt.subplot(1, 2, 1)\n",
"plt.plot(epochs_range, acc[:num_epochs], label='Training Accuracy')\n",
"plt.plot(epochs_range, val_acc[:num_epochs], label='Validation Accuracy')\n",
"plt.title('Training & Validation Accuracy')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Accuracy')\n",
"plt.legend()\n",
"\n",
"# Loss plot\n",
"plt.subplot(1, 2, 2)\n",
"plt.plot(epochs_range, loss[:num_epochs], label='Training Loss')\n",
"plt.plot(epochs_range, val_loss[:num_epochs], label='Validation Loss')\n",
"plt.title('Training & Validation Loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss')\n",
"plt.legend()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iDUykT_awZ9h"
},
"source": [
"## Testing the model"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "i_oY4H3LwZ9l",
"outputId": "371b3f40-25d0-416c-fd0f-e3eb6b0df943"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 543ms/step\n"
]
}
],
"source": [
"## Save the Weights\n",
"model.save_weights(\"UNetW.weights.h5\")\n",
"\n",
"## Dataset for prediction\n",
"x, y = valid_gen.__getitem__(2)\n",
"result = model.predict(x)\n",
"\n",
"result = result > 0.5"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DLdgSxZKpQcQ"
},
"outputs": [],
"source": [
"\n",
"# Load the weights\n",
"model.load_weights(\"UNetW.weights.h5\")\n",
"\n",
"# Now you can use the model for predictions or further training\n",
"x, y = valid_gen.__getitem__(2)\n",
"result = model.predict(x)\n",
"result = result > 0.5"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 309
},
"id": "uSdH42qVwZ9q",
"outputId": "95c65b4e-5f31-4a9e-88a1-d416166ec239"
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7caa43522c80>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plt.figure()\n",
"fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
"\n",
"ax = fig.add_subplot(1, 2, 1)\n",
"ax.imshow(np.reshape(y[0]*255, (image_size, image_size)), cmap=\"gray\")\n",
"\n",
"ax = fig.add_subplot(1, 2, 2)\n",
"ax.imshow(np.reshape(result[0]*255, (image_size, image_size)), cmap=\"gray\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 309
},
"id": "1OKwJGbVwZ-A",
"outputId": "8ef32166-359f-4023-f3a4-f12aeb895db0"
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7caa56311c30>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plt.figure()\n",
"fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
"\n",
"ax = fig.add_subplot(1, 2, 1)\n",
"ax.imshow(np.reshape(y[1]*255, (image_size, image_size)), cmap=\"gray\")\n",
"\n",
"ax = fig.add_subplot(1, 2, 2)\n",
"ax.imshow(np.reshape(result[1]*255, (image_size, image_size)), cmap=\"gray\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CNN Model (Madhurya)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "xBk2k9vzKRD5",
"outputId": "e9880f4e-a1a6-4c3b-aac7-bacca57154e5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset shape : (1717, 64, 64, 1) (1717, 32, 32, 1)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n",
" super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model Summary:\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
"│ conv2d_19 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">54</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">54</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">100</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">12,200</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ average_pooling2d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">AveragePooling2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">100</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ reshape (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Reshape</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8100</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1024</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">8,295,424</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ reshape_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Reshape</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
"│ conv2d_19 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m54\u001b[0m, \u001b[38;5;34m54\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m12,200\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ average_pooling2d (\u001b[38;5;33mAveragePooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ reshape (\u001b[38;5;33mReshape\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8100\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m) │ \u001b[38;5;34m8,295,424\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ reshape_1 (\u001b[38;5;33mReshape\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">8,307,624</span> (31.69 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m8,307,624\u001b[0m (31.69 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">8,307,624</span> (31.69 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m8,307,624\u001b[0m (31.69 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m24s\u001b[0m 260ms/step - accuracy: 0.9048 - loss: 0.1698 - val_accuracy: 0.9131 - val_loss: 0.0750\n",
"Epoch 2/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 266ms/step - accuracy: 0.9424 - loss: 0.0506 - val_accuracy: 0.9195 - val_loss: 0.0692\n",
"Epoch 3/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m40s\u001b[0m 260ms/step - accuracy: 0.9451 - loss: 0.0481 - val_accuracy: 0.9209 - val_loss: 0.0683\n",
"Epoch 4/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m25s\u001b[0m 291ms/step - accuracy: 0.9532 - loss: 0.0424 - val_accuracy: 0.9235 - val_loss: 0.0640\n",
"Epoch 5/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m38s\u001b[0m 262ms/step - accuracy: 0.9594 - loss: 0.0375 - val_accuracy: 0.9282 - val_loss: 0.0603\n",
"Epoch 6/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 262ms/step - accuracy: 0.9649 - loss: 0.0328 - val_accuracy: 0.9348 - val_loss: 0.0550\n",
"Epoch 7/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 265ms/step - accuracy: 0.9682 - loss: 0.0304 - val_accuracy: 0.9350 - val_loss: 0.0541\n",
"Epoch 8/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 262ms/step - accuracy: 0.9701 - loss: 0.0286 - val_accuracy: 0.9352 - val_loss: 0.0531\n",
"Epoch 9/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m25s\u001b[0m 296ms/step - accuracy: 0.9712 - loss: 0.0270 - val_accuracy: 0.9343 - val_loss: 0.0536\n",
"Epoch 10/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m39s\u001b[0m 277ms/step - accuracy: 0.9725 - loss: 0.0262 - val_accuracy: 0.9375 - val_loss: 0.0515\n",
"Epoch 11/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m39s\u001b[0m 260ms/step - accuracy: 0.9727 - loss: 0.0255 - val_accuracy: 0.9396 - val_loss: 0.0502\n",
"Epoch 12/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 258ms/step - accuracy: 0.9733 - loss: 0.0247 - val_accuracy: 0.9399 - val_loss: 0.0493\n",
"Epoch 13/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m25s\u001b[0m 288ms/step - accuracy: 0.9743 - loss: 0.0239 - val_accuracy: 0.9381 - val_loss: 0.0498\n",
"Epoch 14/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m23s\u001b[0m 266ms/step - accuracy: 0.9742 - loss: 0.0237 - val_accuracy: 0.9384 - val_loss: 0.0498\n",
"Epoch 15/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 268ms/step - accuracy: 0.9752 - loss: 0.0230 - val_accuracy: 0.9387 - val_loss: 0.0498\n",
"Epoch 16/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m25s\u001b[0m 289ms/step - accuracy: 0.9754 - loss: 0.0227 - val_accuracy: 0.9392 - val_loss: 0.0496\n",
"Epoch 17/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m23s\u001b[0m 262ms/step - accuracy: 0.9752 - loss: 0.0227 - val_accuracy: 0.9382 - val_loss: 0.0500\n",
"Epoch 18/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m25s\u001b[0m 289ms/step - accuracy: 0.9758 - loss: 0.0221 - val_accuracy: 0.9403 - val_loss: 0.0489\n",
"Epoch 19/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m39s\u001b[0m 271ms/step - accuracy: 0.9764 - loss: 0.0214 - val_accuracy: 0.9387 - val_loss: 0.0491\n",
"Epoch 20/20\n",
"\u001b[1m86/86\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m40s\u001b[0m 264ms/step - accuracy: 0.9763 - loss: 0.0217 - val_accuracy: 0.9396 - val_loss: 0.0494\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m108/108\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 67ms/step\n",
"y_pred shape: (1717, 32, 32, 1)\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 800x400 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import os\n",
"import cv2\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import keras\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Activation, Flatten, Reshape\n",
"from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D\n",
"from keras import regularizers\n",
"\n",
"# Set your data directory\n",
"data_dir = '/content/drive/MyDrive/UNet Dataset/Cardiac Dataset/train'\n",
"image_dir = os.path.join(data_dir, \"images\")\n",
"mask_dir = os.path.join(data_dir, \"masks\")\n",
"\n",
"# List both folders up front: mask_names must exist before image_names is\n",
"# filtered against it (the previous order raised NameError on a fresh kernel).\n",
"image_names = sorted(os.listdir(image_dir))\n",
"mask_names = sorted(os.listdir(mask_dir))\n",
"\n",
"# Keep only the images that have a mask with the same file name.\n",
"mask_name_set = set(mask_names)  # set lookup instead of scanning the list per image\n",
"image_names = [name for name in image_names if name in mask_name_set]\n",
"\n",
"# Make sure that image_names and mask_names correspond one-to-one\n",
"# If they differ, ensure file naming consistency before running.\n",
"assert len(image_names) == len(mask_names), \"Number of images and masks do not match.\"\n",
"\n",
"# Desired shapes (square, so cv2.resize's (width, height) order does not matter here)\n",
"input_image_shape = (64, 64)  # For the CNN input\n",
"mask_shape = (32, 32)         # For the CNN output\n",
"\n",
"X = []\n",
"Y = []\n",
"\n",
"# Both name lists are sorted, so zip pairs the i-th image with the i-th mask.\n",
"for img_name, msk_name in zip(image_names, mask_names):\n",
"    img_path = os.path.join(image_dir, img_name)\n",
"    msk_path = os.path.join(mask_dir, msk_name)\n",
"\n",
"    # Read images in grayscale\n",
"    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)\n",
"    msk = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)\n",
"\n",
"    # cv2.imread signals failure by returning None instead of raising, so fail\n",
"    # here with the offending paths rather than with an opaque error in cv2.resize.\n",
"    if img is None or msk is None:\n",
"        raise ValueError(f\"Could not read image/mask pair: {img_path} / {msk_path}\")\n",
"\n",
"    # Resize to desired shapes\n",
"    img = cv2.resize(img, input_image_shape)\n",
"    msk = cv2.resize(msk, mask_shape)\n",
"\n",
"    # Scale 8-bit pixel values into [0, 1]\n",
"    img = img.astype(np.float32) / 255.0\n",
"    msk = msk.astype(np.float32) / 255.0\n",
"\n",
"    # Add channel dimension\n",
"    img = np.expand_dims(img, axis=-1)  # (64,64,1)\n",
"    msk = np.expand_dims(msk, axis=-1)  # (32,32,1)\n",
"\n",
"    X.append(img)\n",
"    Y.append(msk)\n",
"\n",
"# Stack the per-sample arrays into single batch arrays\n",
"X = np.array(X)\n",
"Y = np.array(Y)\n",
"\n",
"print('Dataset shape :', X.shape, Y.shape)  # X: (N,64,64,1), Y: (N,32,32,1)\n",
"\n",
"# Create the CNN model\n",
"def create_model(input_shape=(64, 64, 1), output_shape=(32, 32, 1)):\n",
"    \"\"\"\n",
"    Simple convnet model: one convolution, one average pooling and one fully connected layer\n",
"    ending with a reshape to the mask shape.\n",
"\n",
"    Args:\n",
"        input_shape: shape of one input sample, default (64, 64, 1).\n",
"        output_shape: shape of one predicted mask, default (32, 32, 1). The Dense\n",
"            layer gets one unit per mask element.\n",
"\n",
"    Returns:\n",
"        The uncompiled Sequential model.\n",
"    \"\"\"\n",
"    model = Sequential()\n",
"    # Declare the input explicitly instead of passing input_shape= to the first layer\n",
"    model.add(keras.Input(shape=input_shape))\n",
"    # Conv layer\n",
"    model.add(Conv2D(100, (11,11), padding='valid', strides=(1, 1)))\n",
"    # Average Pooling\n",
"    model.add(AveragePooling2D((6,6)))\n",
"    # Flatten step\n",
"    # With the defaults:\n",
"    # Input: (64x64x1) -> Conv(11x11): (54x54x100) -> AvgPool(6x6): (9x9x100) = 8100 features\n",
"    # Flatten infers that size, so other input shapes work too (was a hard-coded Reshape((8100,))).\n",
"    model.add(Flatten())\n",
"    # One sigmoid unit per mask element: 32*32*1 = 1024 with the defaults\n",
"    mask_units = int(np.prod(output_shape))\n",
"    model.add(Dense(mask_units, activation='sigmoid', kernel_regularizer=regularizers.l2(0.0001)))\n",
"    # Fold the flat prediction back into the mask shape\n",
"    model.add(Reshape(tuple(output_shape)))\n",
"    return model\n",
"\n",
"# Training hyperparameters\n",
"batch_size = 16\n",
"epochs = 20\n",
"\n",
"# Build the network, then attach the Adam optimizer and the MSE loss\n",
"m = create_model()\n",
"m.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])\n",
"print('Model Summary:')\n",
"m.summary()\n",
"\n",
"# Train on the in-memory arrays, holding out 20% of the samples for validation\n",
"fit_options = dict(epochs=epochs, batch_size=batch_size, validation_split=0.2)\n",
"history = m.fit(X, Y, **fit_options)\n",
"\n",
"# Learning curve: training vs. validation loss per epoch\n",
"curve_fig, curve_ax = plt.subplots(figsize=(10,5))\n",
"curve_ax.plot(history.history['loss'], label='Training Loss')\n",
"curve_ax.plot(history.history['val_loss'], label='Validation Loss', linestyle='--')\n",
"curve_ax.set(title=\"Learning Curve\", xlabel=\"Epochs\", ylabel=\"Loss\")\n",
"curve_ax.legend()\n",
"plt.show()\n",
"\n",
"# Persist the trained weights\n",
"m.save_weights(\"UNetW.weights.h5\")\n",
"\n",
"# Run the model over the whole input array\n",
"y_pred = m.predict(X, batch_size=batch_size)\n",
"print(\"y_pred shape:\", y_pred.shape)\n",
"\n",
"# Show one input next to the mask predicted for it\n",
"idx = 0\n",
"fig, ax = plt.subplots(1, 2, figsize=(8,4))\n",
"panels = [\n",
"    (X[idx].reshape(64,64), 'Input Image'),\n",
"    (y_pred[idx].reshape(32,32), 'Predicted Mask'),\n",
"]\n",
"for axis, (pixels, title) in zip(ax, panels):\n",
"    # Drop the channel axis via reshape above so imshow gets a 2-D grayscale array\n",
"    axis.imshow(pixels, cmap='gray')\n",
"    axis.set_title(title)\n",
"plt.show()\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 0
}