Diff of /UNet_CNN_AddUNet.ipynb [000000] .. [c1eed3]

--- a
+++ b/UNet_CNN_AddUNet.ipynb
@@ -0,0 +1,780 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "nDXW7A7pwZ7a"
+   },
+   "source": [
+    "# UNET SEGMENTATION (SUJAL)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "M5oF6qqZwZ70"
+   },
+   "outputs": [],
+   "source": [
+    "## Imports\n",
+    "import os\n",
+    "import sys\n",
+    "import random\n",
+    "\n",
+    "import numpy as np\n",
+    "import cv2\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "import tensorflow as tf\n",
+    "from tensorflow import keras\n",
+    "\n",
+    "## Seeding\n",
+    "seed = 42\n",
+    "random.seed = seed\n",
+    "np.random.seed = seed\n",
+    "tf.seed = seed"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "70jPovBLwZ7-"
+   },
+   "source": [
+    "## Data "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "SvRbQp8LwZ8B"
+   },
+   "outputs": [],
+   "source": [
+    "    class DataGen(keras.utils.Sequence):\n",
+    "        def __init__(self, ids, path, batch_size=8, image_size=128):\n",
+    "            self.ids = ids\n",
+    "            self.path = path\n",
+    "            self.batch_size = batch_size\n",
+    "            self.image_size = image_size\n",
+    "            self.on_epoch_end()\n",
+    "\n",
+    "        def __load__(self, id_name):\n",
+    "            # Construct image path\n",
+    "            image_path = os.path.join(self.path, \"images\", id_name)\n",
+    "\n",
+    "            # Construct mask path\n",
+    "            mask_path = os.path.join(self.path, \"masks\", id_name)\n",
+    "\n",
+    "            # Read and validate image\n",
+    "            image = cv2.imread(image_path, 1)\n",
+    "            if image is None:\n",
+    "                print(f\"Warning: Failed to load image: {image_path}\")\n",
+    "                return None, None\n",
+    "\n",
+    "            image = cv2.resize(image, (self.image_size, self.image_size))\n",
+    "\n",
+    "            # Read and validate mask\n",
+    "            mask = cv2.imread(mask_path, 0)  # mask often is a single-channel image\n",
+    "            if mask is None:\n",
+    "                print(f\"Warning: Failed to load mask: {mask_path}\")\n",
+    "                return None, None\n",
+    "\n",
+    "            mask = cv2.resize(mask, (self.image_size, self.image_size))\n",
+    "            mask = np.expand_dims(mask, axis=-1)\n",
+    "\n",
+    "            # Normalize\n",
+    "            image = image / 255.0\n",
+    "            mask = mask / 255.0\n",
+    "\n",
+    "            return image, mask\n",
+    "\n",
+    "        def __getitem__(self, index):\n",
+    "            if (index + 1) * self.batch_size > len(self.ids):\n",
+    "                self.batch_size = len(self.ids) - index * self.batch_size\n",
+    "\n",
+    "            files_batch = self.ids[index * self.batch_size: (index + 1) * self.batch_size]\n",
+    "\n",
+    "            image = []\n",
+    "            mask = []\n",
+    "\n",
+    "            for id_name in files_batch:\n",
+    "                _img, _mask = self.__load__(id_name)\n",
+    "                if _img is not None and _mask is not None:  # Only add valid samples\n",
+    "                    image.append(_img)\n",
+    "                    mask.append(_mask)\n",
+    "\n",
+    "            if len(image) == 0:  # Handle case where all files in the batch fail\n",
+    "                raise ValueError(\"No valid images or masks found in the current batch.\")\n",
+    "\n",
+    "            image = np.array(image)\n",
+    "            mask = np.array(mask)\n",
+    "\n",
+    "            return image, mask\n",
+    "\n",
+    "        def on_epoch_end(self):\n",
+    "            pass\n",
+    "\n",
+    "        def __len__(self):\n",
+    "            return int(np.ceil(len(self.ids) / float(self.batch_size)))"
+   ]
+  },
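+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`on_epoch_end` above is a deliberate no-op, so batches arrive in the same order every epoch. A minimal sketch of the usual shuffling variant (`ShuffledDataGen` is a hypothetical addition, not part of the generator above):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ShuffledDataGen(DataGen):\n",
+    "    \"\"\"Hypothetical variant of DataGen that reshuffles ids between epochs.\"\"\"\n",
+    "    def on_epoch_end(self):\n",
+    "        # Keras calls this hook after every epoch (and __init__ calls it once)\n",
+    "        random.shuffle(self.ids)"
+   ]
+  },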
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "rUG6cDK1wZ8I"
+   },
+   "source": [
+    "## Hyperparameters"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "2QwbWknPwZ8M"
+   },
+   "outputs": [],
+   "source": [
+    "image_size = 128\n",
+    "train_path = \"/Users/sandhyakilari/Desktop/Fall Semester 2024/Computer Vision/Segmentation Project\"\n",
+    "epochs = 30\n",
+    "batch_size = 8\n",
+    "\n",
+    "## Training Ids\n",
+    "train_ids = os.listdir(os.path.join(train_path, \"images\"))\n",
+    "\n",
+    "\n",
+    "## Validation Data Size\n",
+    "val_data_size = 10\n",
+    "\n",
+    "valid_ids = train_ids[:val_data_size]\n",
+    "train_ids = train_ids[val_data_size:]"
+   ]
+  },
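+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`os.listdir` returns names in arbitrary (often alphabetical) order, so taking the first ten files as validation can give a biased split. An optional sketch that shuffles deterministically with the seed from above before splitting (an added suggestion, not the original setup):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional: deterministic shuffle before the train/validation split\n",
+    "all_ids = sorted(os.listdir(os.path.join(train_path, \"images\")))\n",
+    "random.Random(seed).shuffle(all_ids)\n",
+    "valid_ids = all_ids[:val_data_size]\n",
+    "train_ids = all_ids[val_data_size:]"
+   ]
+  },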
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "6_cMh29vwZ8R",
+    "outputId": "534e3fce-2174-4f8d-99d3-c19550f24f70"
+   },
+   "outputs": [],
+   "source": [
+    "gen = DataGen(train_ids, train_path, batch_size=batch_size, image_size=image_size)\n",
+    "x, y = gen.__getitem__(0)\n",
+    "print(x.shape, y.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 309
+    },
+    "id": "NCG2jWo9wZ8c",
+    "outputId": "f158147f-b3b7-4c7f-e0a0-13219d136aa7",
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "r = random.randint(0, len(x)-1)\n",
+    "\n",
+    "fig = plt.figure()\n",
+    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
+    "ax = fig.add_subplot(1, 2, 1)\n",
+    "ax.imshow(x[r])\n",
+    "ax = fig.add_subplot(1, 2, 2)\n",
+    "ax.imshow(np.reshape(y[r], (image_size, image_size)), cmap=\"gray\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "1xJMvRqswZ8o"
+   },
+   "source": [
+    "## Different Convolutional Blocks"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "xtA9Rr75wZ8p"
+   },
+   "outputs": [],
+   "source": [
+    "def down_block(x, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n",
+    "    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(x)\n",
+    "    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(c)\n",
+    "    p = keras.layers.MaxPool2D((2, 2), (2, 2))(c)\n",
+    "    return c, p\n",
+    "\n",
+    "def up_block(x, skip, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n",
+    "    us = keras.layers.UpSampling2D((2, 2))(x)\n",
+    "    concat = keras.layers.Concatenate()([us, skip])\n",
+    "    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(concat)\n",
+    "    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(c)\n",
+    "    return c\n",
+    "\n",
+    "def bottleneck(x, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n",
+    "    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(x)\n",
+    "    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(c)\n",
+    "    return c"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "QMdNQ_HVwZ82"
+   },
+   "source": [
+    "## UNet Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "YAa1RbB8wZ89"
+   },
+   "outputs": [],
+   "source": [
+    "def UNet():\n",
+    "    f = [16, 32, 64, 128, 256]\n",
+    "    inputs = keras.layers.Input((image_size, image_size, 3))\n",
+    "\n",
+    "    p0 = inputs\n",
+    "    c1, p1 = down_block(p0, f[0])  # 128 --> 64\n",
+    "    c2, p2 = down_block(p1, f[1])  # 64  --> 32\n",
+    "    c3, p3 = down_block(p2, f[2])  # 32  --> 16\n",
+    "    c4, p4 = down_block(p3, f[3])  # 16  --> 8\n",
+    "\n",
+    "    bn = bottleneck(p4, f[4])\n",
+    "\n",
+    "    u1 = up_block(bn, c4, f[3])  # 8  --> 16\n",
+    "    u2 = up_block(u1, c3, f[2])  # 16 --> 32\n",
+    "    u3 = up_block(u2, c2, f[1])  # 32 --> 64\n",
+    "    u4 = up_block(u3, c1, f[0])  # 64 --> 128\n",
+    "\n",
+    "    outputs = keras.layers.Conv2D(1, (1, 1), padding=\"same\", activation=\"sigmoid\")(u4)\n",
+    "    model = keras.models.Model(inputs, outputs)\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 1000
+    },
+    "id": "ntBPY06awZ9E",
+    "outputId": "bc8241cf-1ca4-46a5-ddf1-afcb43ddc395",
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "model = UNet()\n",
+    "model.compile(optimizer=\"adam\", loss=\"binary_crossentropy\", metrics=[\"acc\"])\n",
+    "model.summary()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Nmkxq4-BwZ9Q"
+   },
+   "source": [
+    "## Training the model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "6I34e4tcwZ9T",
+    "outputId": "427bd73a-0c95-4b03-ce95-6584de960233",
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "train_gen = DataGen(train_ids, train_path, image_size=image_size, batch_size=batch_size)\n",
+    "valid_gen = DataGen(valid_ids, train_path, image_size=image_size, batch_size=batch_size)\n",
+    "\n",
+    "train_steps = len(train_ids)//batch_size\n",
+    "valid_steps = len(valid_ids)//batch_size\n",
+    "\n",
+    "history = model.fit(train_gen, validation_data=valid_gen, steps_per_epoch=train_steps,\n",
+    "          validation_steps=valid_steps, epochs=epochs)\n",
+    "\n",
+    "# Save the weights\n",
+    "model.save_weights(\"UNetW.weights.h5\")\n",
+    "\n",
+    "# Plot training & validation accuracy and loss\n",
+    "acc = history.history['acc']\n",
+    "val_acc = history.history['val_acc']\n",
+    "loss = history.history['loss']\n",
+    "val_loss = history.history['val_loss']\n",
+    "num_epochs = min(len(acc), len(val_acc))\n",
+    "epochs_range = range(1, num_epochs + 1)\n",
+    "\n",
+    "plt.figure(figsize=(16, 6))\n",
+    "\n",
+    "# Accuracy plot\n",
+    "plt.subplot(1, 2, 1)\n",
+    "plt.plot(epochs_range, acc[:num_epochs], label='Training Accuracy')\n",
+    "plt.plot(epochs_range, val_acc[:num_epochs], label='Validation Accuracy')\n",
+    "plt.title('Training & Validation Accuracy')\n",
+    "plt.xlabel('Epoch')\n",
+    "plt.ylabel('Accuracy')\n",
+    "plt.legend()\n",
+    "\n",
+    "# Loss plot\n",
+    "plt.subplot(1, 2, 2)\n",
+    "plt.plot(epochs_range, loss[:num_epochs], label='Training Loss')\n",
+    "plt.plot(epochs_range, val_loss[:num_epochs], label='Validation Loss')\n",
+    "plt.title('Training & Validation Loss')\n",
+    "plt.xlabel('Epoch')\n",
+    "plt.ylabel('Loss')\n",
+    "plt.legend()\n",
+    "\n",
+    "plt.show()"
+   ]
+  },
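+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The run above trains for a fixed 30 epochs and saves weights only at the end. A sketch of the same `fit` call with standard Keras callbacks that checkpoint the best validation loss and stop early (the patience value is an arbitrary choice, not from the original):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "callbacks = [\n",
+    "    # Keep only the weights with the best validation loss seen so far\n",
+    "    keras.callbacks.ModelCheckpoint(\"UNetW.weights.h5\", save_weights_only=True,\n",
+    "                                    save_best_only=True, monitor=\"val_loss\"),\n",
+    "    # Stop once validation loss has not improved for 5 consecutive epochs\n",
+    "    keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),\n",
+    "]\n",
+    "\n",
+    "history = model.fit(train_gen, validation_data=valid_gen, steps_per_epoch=train_steps,\n",
+    "                    validation_steps=valid_steps, epochs=epochs, callbacks=callbacks)"
+   ]
+  },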
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "iDUykT_awZ9h"
+   },
+   "source": [
+    "## Testing the model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "i_oY4H3LwZ9l",
+    "outputId": "371b3f40-25d0-416c-fd0f-e3eb6b0df943"
+   },
+   "outputs": [],
+   "source": [
+    "## Save the Weights\n",
+    "model.save_weights(\"UNetW.weights.h5\")\n",
+    "\n",
+    "## Dataset for prediction\n",
+    "x, y = valid_gen.__getitem__(2)\n",
+    "result = model.predict(x)\n",
+    "\n",
+    "result = result > 0.5"
+   ]
+  },
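+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Pixel accuracy is a weak score for segmentation, since the background dominates. A minimal NumPy sketch of Dice and IoU over the batch just predicted (an added check, not part of the original notebook):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pred = result.astype(np.float32)       # thresholded predictions\n",
+    "true = (y > 0.5).astype(np.float32)    # binarized ground truth\n",
+    "\n",
+    "intersection = np.sum(pred * true)\n",
+    "dice = 2.0 * intersection / (np.sum(pred) + np.sum(true) + 1e-7)\n",
+    "iou = intersection / (np.sum(pred) + np.sum(true) - intersection + 1e-7)\n",
+    "print(f\"Dice: {dice:.4f}  IoU: {iou:.4f}\")"
+   ]
+  },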
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "DLdgSxZKpQcQ"
+   },
+   "outputs": [],
+   "source": [
+    "\n",
+    "# Load the weights\n",
+    "model.load_weights(\"UNetW.weights.h5\")\n",
+    "\n",
+    "# Now you can use the model for predictions or further training\n",
+    "x, y = valid_gen.__getitem__(2)\n",
+    "result = model.predict(x)\n",
+    "result = result > 0.5"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 309
+    },
+    "id": "uSdH42qVwZ9q",
+    "outputId": "95c65b4e-5f31-4a9e-88a1-d416166ec239"
+   },
+   "outputs": [],
+   "source": [
+    "fig = plt.figure()\n",
+    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
+    "\n",
+    "ax = fig.add_subplot(1, 2, 1)\n",
+    "ax.imshow(np.reshape(y[0]*255, (image_size, image_size)), cmap=\"gray\")\n",
+    "\n",
+    "ax = fig.add_subplot(1, 2, 2)\n",
+    "ax.imshow(np.reshape(result[0]*255, (image_size, image_size)), cmap=\"gray\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 309
+    },
+    "id": "1OKwJGbVwZ-A",
+    "outputId": "8ef32166-359f-4023-f3a4-f12aeb895db0"
+   },
+   "outputs": [],
+   "source": [
+    "fig = plt.figure()\n",
+    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
+    "\n",
+    "ax = fig.add_subplot(1, 2, 1)\n",
+    "ax.imshow(np.reshape(y[1]*255, (image_size, image_size)), cmap=\"gray\")\n",
+    "\n",
+    "ax = fig.add_subplot(1, 2, 2)\n",
+    "ax.imshow(np.reshape(result[1]*255, (image_size, image_size)), cmap=\"gray\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# CNN Model (Madhurya)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 1000
+    },
+    "id": "xBk2k9vzKRD5",
+    "outputId": "e9880f4e-a1a6-4c3b-aac7-bacca57154e5"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import cv2\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "import keras\n",
+    "from keras.models import Sequential\n",
+    "from keras.layers import Dense, Activation, Flatten, Reshape\n",
+    "from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D\n",
+    "from keras import regularizers\n",
+    "\n",
+    "# Set your data directory\n",
+    "data_dir = '/Users/sandhyakilari/Desktop/Fall Semester 2024/Computer Vision/Segmentation Project'\n",
+    "image_dir = os.path.join(data_dir, \"images\")\n",
+    "mask_dir = os.path.join(data_dir, \"masks\")\n",
+    "\n",
+    "# List files\n",
+    "image_names = sorted(os.listdir(image_dir))\n",
+    "mask_names = sorted(os.listdir(mask_dir))\n",
+    "\n",
+    "image_names = [name for name in image_names if name in mask_names]\n",
+    "\n",
+    "# Make sure that image_names and mask_names correspond one-to-one\n",
+    "# If they differ, ensure file naming consistency before running.\n",
+    "assert len(image_names) == len(mask_names), \"Number of images and masks do not match.\"\n",
+    "\n",
+    "# Desired shapes\n",
+    "input_image_shape = (64, 64)  # For the CNN input\n",
+    "mask_shape = (32, 32)         # For the CNN output\n",
+    "\n",
+    "X = []\n",
+    "Y = []\n",
+    "\n",
+    "for img_name, msk_name in zip(image_names, mask_names):\n",
+    "    img_path = os.path.join(image_dir, img_name)\n",
+    "    msk_path = os.path.join(mask_dir, msk_name)\n",
+    "\n",
+    "    # Read images in grayscale\n",
+    "    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)\n",
+    "    msk = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)\n",
+    "\n",
+    "    # Resize to desired shapes\n",
+    "    img = cv2.resize(img, input_image_shape)\n",
+    "    msk = cv2.resize(msk, mask_shape)\n",
+    "\n",
+    "    # Normalize (if you want to normalize)\n",
+    "    img = img.astype(np.float32) / 255.0\n",
+    "    msk = msk.astype(np.float32) / 255.0\n",
+    "\n",
+    "    # Add channel dimension\n",
+    "    img = np.expand_dims(img, axis=-1)  # (64,64,1)\n",
+    "    msk = np.expand_dims(msk, axis=-1)  # (32,32,1)\n",
+    "\n",
+    "    X.append(img)\n",
+    "    Y.append(msk)\n",
+    "\n",
+    "X = np.array(X)\n",
+    "Y = np.array(Y)\n",
+    "\n",
+    "print('Dataset shape :', X.shape, Y.shape)  # X: (N,64,64,1), Y: (N,32,32,1)\n",
+    "\n",
+    "# Create the CNN model\n",
+    "def create_model(input_shape=(64, 64, 1)):\n",
+    "    \"\"\"\n",
+    "    Simple convnet model: one convolution, one average pooling and one fully connected layer\n",
+    "    ending with a reshape to (32,32,1).\n",
+    "    \"\"\"\n",
+    "    model = Sequential()\n",
+    "    # Conv layer\n",
+    "    model.add(Conv2D(100, (11,11), padding='valid', strides=(1, 1), input_shape=input_shape))\n",
+    "    # Average Pooling\n",
+    "    model.add(AveragePooling2D((6,6)))\n",
+    "    # Flatten/Reshape step\n",
+    "    # After Conv+Pool:\n",
+    "    # Input: (64x64x1) -> Conv(11x11): (54x54x100) -> AvgPool(6x6): (9x9x100) = 8100 features\n",
+    "    model.add(Reshape((8100,)))  # Flatten to (8100,)\n",
+    "    model.add(Dense(1024, activation='sigmoid', kernel_regularizer=regularizers.l2(0.0001)))\n",
+    "    # Now we have (1024,). We want (32,32,1):\n",
+    "    # 32*32 = 1024, so we reshape to (32,32,1)\n",
+    "    model.add(Reshape((32,32,1)))\n",
+    "    return model\n",
+    "\n",
+    "m = create_model()\n",
+    "m.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])\n",
+    "print('Model Summary:')\n",
+    "m.summary()\n",
+    "\n",
+    "# Train the model\n",
+    "epochs = 20\n",
+    "batch_size = 16\n",
+    "history = m.fit(X, Y, batch_size=batch_size, epochs=epochs, validation_split=0.2)\n",
+    "\n",
+    "# Plot training and validation loss\n",
+    "plt.figure(figsize=(10,5))\n",
+    "plt.plot(history.history['loss'], label='Training Loss')\n",
+    "plt.plot(history.history['val_loss'], label='Validation Loss', linestyle='--')\n",
+    "plt.title(\"Learning Curve\")\n",
+    "plt.xlabel(\"Epochs\")\n",
+    "plt.ylabel(\"Loss\")\n",
+    "plt.legend()\n",
+    "plt.show()\n",
+    "\n",
+    "# Save the weights\n",
+    "m.save_weights(\"UNetW.weights.h5\")\n",
+    "\n",
+    "# Example prediction\n",
+    "y_pred = m.predict(X, batch_size=batch_size)\n",
+    "print(\"y_pred shape:\", y_pred.shape)\n",
+    "\n",
+    "# Visualize a sample\n",
+    "idx = 0\n",
+    "fig, ax = plt.subplots(1, 2, figsize=(8,4))\n",
+    "ax[0].imshow(X[idx].reshape(64,64), cmap='gray')\n",
+    "ax[0].set_title('Input Image')\n",
+    "ax[1].imshow(y_pred[idx].reshape(32,32), cmap='gray')\n",
+    "ax[1].set_title('Predicted Mask')\n",
+    "plt.show()"
+   ]
+  },
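+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The figure above omits the ground-truth mask, which makes the prediction hard to judge. A small added sketch showing input, ground truth, and prediction side by side:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "idx = 0\n",
+    "fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n",
+    "ax[0].imshow(X[idx].reshape(64, 64), cmap='gray')\n",
+    "ax[0].set_title('Input Image')\n",
+    "ax[1].imshow(Y[idx].reshape(32, 32), cmap='gray')\n",
+    "ax[1].set_title('Ground Truth')\n",
+    "ax[2].imshow(y_pred[idx].reshape(32, 32), cmap='gray')\n",
+    "ax[2].set_title('Predicted Mask')\n",
+    "plt.show()"
+   ]
+  },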
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# AttentionUNET (Vishal)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import tensorflow as tf\n",
+    "from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Activation, BatchNormalization, Add, Multiply\n",
+    "from tensorflow.keras.models import Model\n",
+    "import os\n",
+    "import numpy as np\n",
+    "from tensorflow.keras.preprocessing.image import load_img, img_to_array\n",
+    "\n",
+    "def attention_block(x, g, inter_channel):\n",
+    "    \"\"\"\n",
+    "    Attention Block: Refines encoder features based on decoder signals.\n",
+    "    x: Input tensor from the encoder (skip connection)\n",
+    "    g: Gating signal from the decoder (upsampled tensor)\n",
+    "    inter_channel: Number of intermediate channels (reduces computation)\n",
+    "    \"\"\"\n",
+    "    # 1x1 Convolution on input tensor\n",
+    "    theta_x = Conv2D(inter_channel, kernel_size=(1, 1), strides=(1, 1), padding='same')(x)\n",
+    "    # 1x1 Convolution on gating tensor\n",
+    "    phi_g = Conv2D(inter_channel, kernel_size=(1, 1), strides=(1, 1), padding='same')(g)\n",
+    "    \n",
+    "    # Add the transformed inputs and apply ReLU\n",
+    "    add_xg = Add()([theta_x, phi_g])\n",
+    "    relu_xg = Activation('relu')(add_xg)\n",
+    "    \n",
+    "    # Another 1x1 Convolution to generate attention coefficients\n",
+    "    psi = Conv2D(1, kernel_size=(1, 1), strides=(1, 1), padding='same')(relu_xg)\n",
+    "    # Sigmoid activation to normalize attention weights\n",
+    "    sigmoid_psi = Activation('sigmoid')(psi)\n",
+    "    \n",
+    "    # Multiply the input tensor with the attention weights\n",
+    "    return Multiply()([x, sigmoid_psi])\n",
+    "\n",
+    "def conv_block(x, filters):\n",
+    "    \"\"\"\n",
+    "    Convolutional Block: Apply two 3x3 convolutions followed by BatchNorm and ReLU.\n",
+    "    x: Input tensor\n",
+    "    filters: Number of output filters for the convolutions\n",
+    "    \"\"\"\n",
+    "    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)\n",
+    "    x = BatchNormalization()(x)\n",
+    "    x = Activation('relu')(x)\n",
+    "    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)\n",
+    "    x = BatchNormalization()(x)\n",
+    "    x = Activation('relu')(x)\n",
+    "    return x\n",
+    "\n",
+    "def attention_unet(input_shape, num_classes):\n",
+    "    \"\"\"\n",
+    "    Attention U-Net model architecture.\n",
+    "    input_shape: Shape of input images (H, W, C)\n",
+    "    num_classes: Number of output segmentation classes\n",
+    "    \"\"\"\n",
+    "    # Input layer for the images\n",
+    "    inputs = Input(input_shape)\n",
+    "    \n",
+    "    # Encoder (Downsampling path)\n",
+    "    c1 = conv_block(inputs, 64)              # First Conv Block\n",
+    "    p1 = MaxPooling2D((2, 2))(c1)            # Downsample by 2\n",
+    "    \n",
+    "    c2 = conv_block(p1, 128)                 # Second Conv Block\n",
+    "    p2 = MaxPooling2D((2, 2))(c2)            # Downsample by 2\n",
+    "    \n",
+    "    c3 = conv_block(p2, 256)                 # Third Conv Block\n",
+    "    p3 = MaxPooling2D((2, 2))(c3)            # Downsample by 2\n",
+    "    \n",
+    "    c4 = conv_block(p3, 512)                 # Fourth Conv Block\n",
+    "    p4 = MaxPooling2D((2, 2))(c4)            # Downsample by 2\n",
+    "    \n",
+    "    # Bottleneck (lowest level of the U-Net)\n",
+    "    c5 = conv_block(p4, 1024)\n",
+    "    \n",
+    "    # Decoder (Upsampling path)\n",
+    "    up6 = UpSampling2D((2, 2))(c5)           # Upsample\n",
+    "    att6 = attention_block(c4, up6, 512)     # Attention Block\n",
+    "    merge6 = concatenate([up6, att6], axis=-1)  # Concatenate features\n",
+    "    c6 = conv_block(merge6, 512)             # Conv Block after concatenation\n",
+    "    \n",
+    "    up7 = UpSampling2D((2, 2))(c6)\n",
+    "    att7 = attention_block(c3, up7, 256)\n",
+    "    merge7 = concatenate([up7, att7], axis=-1)\n",
+    "    c7 = conv_block(merge7, 256)\n",
+    "    \n",
+    "    up8 = UpSampling2D((2, 2))(c7)\n",
+    "    att8 = attention_block(c2, up8, 128)\n",
+    "    merge8 = concatenate([up8, att8], axis=-1)\n",
+    "    c8 = conv_block(merge8, 128)\n",
+    "    \n",
+    "    up9 = UpSampling2D((2, 2))(c8)\n",
+    "    att9 = attention_block(c1, up9, 64)\n",
+    "    merge9 = concatenate([up9, att9], axis=-1)\n",
+    "    c9 = conv_block(merge9, 64)\n",
+    "    \n",
+    "    # Output layer for segmentation\n",
+    "    outputs = Conv2D(num_classes, (1, 1), activation='softmax' if num_classes > 1 else 'sigmoid')(c9)\n",
+    "    \n",
+    "    # Define the model\n",
+    "    model = Model(inputs=inputs, outputs=outputs)\n",
+    "    return model\n",
+    "\n",
+    "# Function to load and preprocess images and masks\n",
+    "def load_data(image_dir, mask_dir, image_size):\n",
+    "    \"\"\"\n",
+    "    Load and preprocess images and masks for training.\n",
+    "    image_dir: Path to the directory containing input images\n",
+    "    mask_dir: Path to the directory containing segmentation masks\n",
+    "    image_size: Tuple specifying the size (height, width) to resize the images and masks\n",
+    "    \"\"\"\n",
+    "    images = []\n",
+    "    masks = []\n",
+    "    image_files = sorted(os.listdir(image_dir))\n",
+    "    mask_files = sorted(os.listdir(mask_dir))\n",
+    "    \n",
+    "    for img_file, mask_file in zip(image_files, mask_files):\n",
+    "        try:\n",
+    "            # Load and preprocess images\n",
+    "            img_path = os.path.join(image_dir, img_file)\n",
+    "            mask_path = os.path.join(mask_dir, mask_file)\n",
+    "            \n",
+    "            img = load_img(img_path, target_size=image_size)  # Resize image\n",
+    "            mask = load_img(mask_path, target_size=image_size, color_mode='grayscale')  # Resize mask\n",
+    "            \n",
+    "            # Convert to numpy arrays and normalize\n",
+    "            img = img_to_array(img) / 255.0\n",
+    "            mask = img_to_array(mask) / 255.0\n",
+    "            mask = np.round(mask)  # Ensure masks are binary\n",
+    "            \n",
+    "            images.append(img)\n",
+    "            masks.append(mask)\n",
+    "        except Exception as e:\n",
+    "            print(f\"Error loading {img_file} or {mask_file}: {e}. Skipping...\")\n",
+    "    \n",
+    "    return np.array(images), np.array(masks)\n",
+    "\n",
+    "# Example usage\n",
+    "if __name__ == \"__main__\":\n",
+    "    # Load data\n",
+    "    image_dir = \"./images/\"  # Replace with your image directory\n",
+    "    mask_dir = \"./masks/\"    # Replace with your mask directory\n",
+    "    image_size = (128, 128)       # Resize all images to 128x128\n",
+    "    images, masks = load_data(image_dir, mask_dir, image_size)\n",
+    "    \n",
+    "    # Define the model\n",
+    "    model = attention_unet(input_shape=(128, 128, 3), num_classes=1)\n",
+    "    \n",
+    "    # Compile the model\n",
+    "    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n",
+    "    \n",
+    "    # Train the model\n",
+    "    model.fit(images, masks, batch_size=8, epochs=20, validation_split=0.1)"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}