{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "nDXW7A7pwZ7a"
},
"source": [
"# UNET SEGMENTATION (SUJAL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M5oF6qqZwZ70"
},
"outputs": [],
"source": [
"## Imports\n",
"import os\n",
"import sys\n",
"import random\n",
"\n",
"import numpy as np\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"\n",
"## Seeding\n",
"seed = 42\n",
"random.seed = seed\n",
"np.random.seed = seed\n",
"tf.seed = seed"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "70jPovBLwZ7-"
},
"source": [
"## Data "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SvRbQp8LwZ8B"
},
"outputs": [],
"source": [
" class DataGen(keras.utils.Sequence):\n",
" def __init__(self, ids, path, batch_size=8, image_size=128):\n",
" self.ids = ids\n",
" self.path = path\n",
" self.batch_size = batch_size\n",
" self.image_size = image_size\n",
" self.on_epoch_end()\n",
"\n",
" def __load__(self, id_name):\n",
" # Construct image path\n",
" image_path = os.path.join(self.path, \"images\", id_name)\n",
"\n",
" # Construct mask path\n",
" mask_path = os.path.join(self.path, \"masks\", id_name)\n",
"\n",
" # Read and validate image\n",
" image = cv2.imread(image_path, 1)\n",
" if image is None:\n",
" print(f\"Warning: Failed to load image: {image_path}\")\n",
" return None, None\n",
"\n",
" image = cv2.resize(image, (self.image_size, self.image_size))\n",
"\n",
" # Read and validate mask\n",
" mask = cv2.imread(mask_path, 0) # mask often is a single-channel image\n",
" if mask is None:\n",
" print(f\"Warning: Failed to load mask: {mask_path}\")\n",
" return None, None\n",
"\n",
" mask = cv2.resize(mask, (self.image_size, self.image_size))\n",
" mask = np.expand_dims(mask, axis=-1)\n",
"\n",
" # Normalize\n",
" image = image / 255.0\n",
" mask = mask / 255.0\n",
"\n",
" return image, mask\n",
"\n",
" def __getitem__(self, index):\n",
" if (index + 1) * self.batch_size > len(self.ids):\n",
" self.batch_size = len(self.ids) - index * self.batch_size\n",
"\n",
" files_batch = self.ids[index * self.batch_size: (index + 1) * self.batch_size]\n",
"\n",
" image = []\n",
" mask = []\n",
"\n",
" for id_name in files_batch:\n",
" _img, _mask = self.__load__(id_name)\n",
" if _img is not None and _mask is not None: # Only add valid samples\n",
" image.append(_img)\n",
" mask.append(_mask)\n",
"\n",
" if len(image) == 0: # Handle case where all files in the batch fail\n",
" raise ValueError(\"No valid images or masks found in the current batch.\")\n",
"\n",
" image = np.array(image)\n",
" mask = np.array(mask)\n",
"\n",
" return image, mask\n",
"\n",
" def on_epoch_end(self):\n",
" pass\n",
"\n",
" def __len__(self):\n",
" return int(np.ceil(len(self.ids) / float(self.batch_size)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rUG6cDK1wZ8I"
},
"source": [
"## Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2QwbWknPwZ8M"
},
"outputs": [],
"source": [
"image_size = 128\n",
"train_path = \"/Users/sandhyakilari/Desktop/Fall Semester 2024/Computer Vision/Segmentation Project\"\n",
"epochs = 30\n",
"batch_size = 8\n",
"\n",
"## Training Ids\n",
"train_ids = os.listdir(os.path.join(train_path, \"images\"))\n",
"\n",
"\n",
"## Validation Data Size\n",
"val_data_size = 10\n",
"\n",
"valid_ids = train_ids[:val_data_size]\n",
"train_ids = train_ids[val_data_size:]"
]
},
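{
"cell_type": "markdown",
"metadata": {},
"source": [
"The split above takes the first `val_data_size` files returned by `os.listdir`, whose order is arbitrary. A minimal sketch of an alternative, assuming a randomly chosen hold-out set is preferred (it reuses the `seed` set earlier):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Optional: shuffled train/validation split (sketch)\n",
"rng = random.Random(seed)        # local RNG; leaves the global random state untouched\n",
"all_ids = train_ids + valid_ids  # recombine the ids listed above\n",
"rng.shuffle(all_ids)\n",
"valid_ids = all_ids[:val_data_size]\n",
"train_ids = all_ids[val_data_size:]"
]
},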
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6_cMh29vwZ8R",
"outputId": "534e3fce-2174-4f8d-99d3-c19550f24f70"
},
"outputs": [],
"source": [
"gen = DataGen(train_ids, train_path, batch_size=batch_size, image_size=image_size)\n",
"x, y = gen.__getitem__(0)\n",
"print(x.shape, y.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 309
},
"id": "NCG2jWo9wZ8c",
"outputId": "f158147f-b3b7-4c7f-e0a0-13219d136aa7",
"scrolled": false
},
"outputs": [],
"source": [
"r = random.randint(0, len(x)-1)\n",
"\n",
"fig = plt.figure()\n",
"fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
"ax = fig.add_subplot(1, 2, 1)\n",
"ax.imshow(x[r])\n",
"ax = fig.add_subplot(1, 2, 2)\n",
"ax.imshow(np.reshape(y[r], (image_size, image_size)), cmap=\"gray\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1xJMvRqswZ8o"
},
"source": [
"## Different Convolutional Blocks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xtA9Rr75wZ8p"
},
"outputs": [],
"source": [
"def down_block(x, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n",
" c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(x)\n",
" c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(c)\n",
" p = keras.layers.MaxPool2D((2, 2), (2, 2))(c)\n",
" return c, p\n",
"\n",
"def up_block(x, skip, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n",
" us = keras.layers.UpSampling2D((2, 2))(x)\n",
" concat = keras.layers.Concatenate()([us, skip])\n",
" c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(concat)\n",
" c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(c)\n",
" return c\n",
"\n",
"def bottleneck(x, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n",
" c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(x)\n",
" c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(c)\n",
" return c"
]
},
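{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of the blocks (a minimal sketch; the underscore names are illustrative and the shapes assume `image_size = 128`): `down_block` keeps a full-resolution feature map for the skip connection and halves the spatial size, while `up_block` doubles it again and fuses the skip."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Shape check for the building blocks (sketch)\n",
"_inp = keras.layers.Input((128, 128, 3))\n",
"_c, _p = down_block(_inp, 16)\n",
"print(_c.shape, _p.shape)  # (None, 128, 128, 16) (None, 64, 64, 16)\n",
"_u = up_block(_p, _c, 16)\n",
"print(_u.shape)            # (None, 128, 128, 16) -- back at the input resolution"
]
},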
{
"cell_type": "markdown",
"metadata": {
"id": "QMdNQ_HVwZ82"
},
"source": [
"## UNet Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YAa1RbB8wZ89"
},
"outputs": [],
"source": [
"def UNet():\n",
" f = [16, 32, 64, 128, 256]\n",
" inputs = keras.layers.Input((image_size, image_size, 3))\n",
"\n",
" p0 = inputs\n",
" c1, p1 = down_block(p0, f[0]) # 128 --> 64\n",
" c2, p2 = down_block(p1, f[1]) # 64 --> 32\n",
" c3, p3 = down_block(p2, f[2]) # 32 --> 16\n",
" c4, p4 = down_block(p3, f[3]) # 16 --> 8\n",
"\n",
" bn = bottleneck(p4, f[4])\n",
"\n",
" u1 = up_block(bn, c4, f[3]) # 8 --> 16\n",
" u2 = up_block(u1, c3, f[2]) # 16 --> 32\n",
" u3 = up_block(u2, c2, f[1]) # 32 --> 64\n",
" u4 = up_block(u3, c1, f[0]) # 64 --> 128\n",
"\n",
" outputs = keras.layers.Conv2D(1, (1, 1), padding=\"same\", activation=\"sigmoid\")(u4)\n",
" model = keras.models.Model(inputs, outputs)\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "ntBPY06awZ9E",
"outputId": "bc8241cf-1ca4-46a5-ddf1-afcb43ddc395",
"scrolled": true
},
"outputs": [],
"source": [
"model = UNet()\n",
"model.compile(optimizer=\"adam\", loss=\"binary_crossentropy\", metrics=[\"acc\"])\n",
"model.summary()"
]
},
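{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick forward pass on random noise (a sketch, just to confirm the wiring before training): the sigmoid head should produce one single-channel mask per input, with values in [0, 1]."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Sanity check: one forward pass on random input (sketch)\n",
"_dummy = np.random.rand(1, image_size, image_size, 3).astype(np.float32)\n",
"_pred = model.predict(_dummy)\n",
"print(_pred.shape, float(_pred.min()), float(_pred.max()))  # (1, 128, 128, 1), values within [0, 1]"
]
},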
{
"cell_type": "markdown",
"metadata": {
"id": "Nmkxq4-BwZ9Q"
},
"source": [
"## Training the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6I34e4tcwZ9T",
"outputId": "427bd73a-0c95-4b03-ce95-6584de960233",
"scrolled": false
},
"outputs": [],
"source": [
"train_gen = DataGen(train_ids, train_path, image_size=image_size, batch_size=batch_size)\n",
"valid_gen = DataGen(valid_ids, train_path, image_size=image_size, batch_size=batch_size)\n",
"\n",
"train_steps = len(train_ids)//batch_size\n",
"valid_steps = len(valid_ids)//batch_size\n",
"\n",
"history = model.fit(train_gen, validation_data=valid_gen, steps_per_epoch=train_steps,\n",
" validation_steps=valid_steps, epochs=epochs)\n",
"\n",
"# Save the weights\n",
"model.save_weights(\"UNetW.weights.h5\")\n",
"\n",
"# Plot training & validation accuracy and loss\n",
"acc = history.history['acc']\n",
"val_acc = history.history['val_acc']\n",
"loss = history.history['loss']\n",
"val_loss = history.history['val_loss']\n",
"num_epochs = min(len(acc), len(val_acc))\n",
"epochs_range = range(1, num_epochs + 1)\n",
"\n",
"plt.figure(figsize=(16, 6))\n",
"\n",
"# Accuracy plot\n",
"plt.subplot(1, 2, 1)\n",
"plt.plot(epochs_range, acc[:num_epochs], label='Training Accuracy')\n",
"plt.plot(epochs_range, val_acc[:num_epochs], label='Validation Accuracy')\n",
"plt.title('Training & Validation Accuracy')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Accuracy')\n",
"plt.legend()\n",
"\n",
"# Loss plot\n",
"plt.subplot(1, 2, 2)\n",
"plt.plot(epochs_range, loss[:num_epochs], label='Training Loss')\n",
"plt.plot(epochs_range, val_loss[:num_epochs], label='Validation Loss')\n",
"plt.title('Training & Validation Loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss')\n",
"plt.legend()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iDUykT_awZ9h"
},
"source": [
"## Testing the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "i_oY4H3LwZ9l",
"outputId": "371b3f40-25d0-416c-fd0f-e3eb6b0df943"
},
"outputs": [],
"source": [
"## Save the Weights\n",
"model.save_weights(\"UNetW.weights.h5\")\n",
"\n",
"## Dataset for prediction\n",
"x, y = valid_gen.__getitem__(2)\n",
"result = model.predict(x)\n",
"\n",
"result = result > 0.5"
]
},
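{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to quantify the thresholded masks (a sketch; intersection-over-union is an added evaluation here, not part of the original pipeline) is to binarize the ground truth the same way and compare the overlap:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Sketch: per-image IoU between thresholded predictions and ground truth\n",
"gt = y > 0.5  # binarize the ground-truth masks with the same threshold\n",
"intersection = np.logical_and(result, gt).sum(axis=(1, 2, 3))\n",
"union = np.logical_or(result, gt).sum(axis=(1, 2, 3))\n",
"iou = intersection / np.maximum(union, 1)  # guard against empty masks\n",
"print(\"Per-image IoU:\", np.round(iou, 3))"
]
},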
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DLdgSxZKpQcQ"
},
"outputs": [],
"source": [
"\n",
"# Load the weights\n",
"model.load_weights(\"UNetW.weights.h5\")\n",
"\n",
"# Now you can use the model for predictions or further training\n",
"x, y = valid_gen.__getitem__(2)\n",
"result = model.predict(x)\n",
"result = result > 0.5"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 309
},
"id": "uSdH42qVwZ9q",
"outputId": "95c65b4e-5f31-4a9e-88a1-d416166ec239"
},
"outputs": [],
"source": [
"fig = plt.figure()\n",
"fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
"\n",
"ax = fig.add_subplot(1, 2, 1)\n",
"ax.imshow(np.reshape(y[0]*255, (image_size, image_size)), cmap=\"gray\")\n",
"\n",
"ax = fig.add_subplot(1, 2, 2)\n",
"ax.imshow(np.reshape(result[0]*255, (image_size, image_size)), cmap=\"gray\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 309
},
"id": "1OKwJGbVwZ-A",
"outputId": "8ef32166-359f-4023-f3a4-f12aeb895db0"
},
"outputs": [],
"source": [
"fig = plt.figure()\n",
"fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
"\n",
"ax = fig.add_subplot(1, 2, 1)\n",
"ax.imshow(np.reshape(y[1]*255, (image_size, image_size)), cmap=\"gray\")\n",
"\n",
"ax = fig.add_subplot(1, 2, 2)\n",
"ax.imshow(np.reshape(result[1]*255, (image_size, image_size)), cmap=\"gray\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CNN Model (Madhurya)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "xBk2k9vzKRD5",
"outputId": "e9880f4e-a1a6-4c3b-aac7-bacca57154e5"
},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import keras\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Activation, Flatten, Reshape\n",
"from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D\n",
"from keras import regularizers\n",
"\n",
"# Set your data directory\n",
"data_dir = '/Users/sandhyakilari/Desktop/Fall Semester 2024/Computer Vision/Segmentation Project'\n",
"image_dir = os.path.join(data_dir, \"images\")\n",
"mask_dir = os.path.join(data_dir, \"masks\")\n",
"\n",
"# List files\n",
"image_names = sorted(os.listdir(image_dir))\n",
"mask_names = sorted(os.listdir(mask_dir))\n",
"\n",
"image_names = [name for name in image_names if name in mask_names]\n",
"\n",
"# Make sure that image_names and mask_names correspond one-to-one\n",
"# If they differ, ensure file naming consistency before running.\n",
"assert len(image_names) == len(mask_names), \"Number of images and masks do not match.\"\n",
"\n",
"# Desired shapes\n",
"input_image_shape = (64, 64) # For the CNN input\n",
"mask_shape = (32, 32) # For the CNN output\n",
"\n",
"X = []\n",
"Y = []\n",
"\n",
"for img_name, msk_name in zip(image_names, mask_names):\n",
" img_path = os.path.join(image_dir, img_name)\n",
" msk_path = os.path.join(mask_dir, msk_name)\n",
"\n",
" # Read images in grayscale\n",
" img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)\n",
" msk = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)\n",
"\n",
" # Resize to desired shapes\n",
" img = cv2.resize(img, input_image_shape)\n",
" msk = cv2.resize(msk, mask_shape)\n",
"\n",
" # Normalize (if you want to normalize)\n",
" img = img.astype(np.float32) / 255.0\n",
" msk = msk.astype(np.float32) / 255.0\n",
"\n",
" # Add channel dimension\n",
" img = np.expand_dims(img, axis=-1) # (64,64,1)\n",
" msk = np.expand_dims(msk, axis=-1) # (32,32,1)\n",
"\n",
" X.append(img)\n",
" Y.append(msk)\n",
"\n",
"X = np.array(X)\n",
"Y = np.array(Y)\n",
"\n",
"print('Dataset shape :', X.shape, Y.shape) # X: (N,64,64,1), Y: (N,32,32,1)\n",
"\n",
"# Create the CNN model\n",
"def create_model(input_shape=(64, 64, 1)):\n",
" \"\"\"\n",
" Simple convnet model: one convolution, one average pooling and one fully connected layer\n",
" ending with a reshape to (32,32,1).\n",
" \"\"\"\n",
" model = Sequential()\n",
" # Conv layer\n",
" model.add(Conv2D(100, (11,11), padding='valid', strides=(1, 1), input_shape=input_shape))\n",
" # Average Pooling\n",
" model.add(AveragePooling2D((6,6)))\n",
" # Flatten/Reshape step\n",
" # After Conv+Pool:\n",
" # Input: (64x64x1) -> Conv(11x11): (54x54x100) -> AvgPool(6x6): (9x9x100) = 8100 features\n",
" model.add(Reshape((8100,))) # Flatten to (8100,)\n",
" model.add(Dense(1024, activation='sigmoid', kernel_regularizer=regularizers.l2(0.0001)))\n",
" # Now we have (1024,). We want (32,32,1):\n",
" # 32*32 = 1024, so we reshape to (32,32,1)\n",
" model.add(Reshape((32,32,1)))\n",
" return model\n",
"\n",
"m = create_model()\n",
"m.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])\n",
"print('Model Summary:')\n",
"m.summary()\n",
"\n",
"# Train the model\n",
"epochs = 20\n",
"batch_size = 16\n",
"history = m.fit(X, Y, batch_size=batch_size, epochs=epochs, validation_split=0.2)\n",
"\n",
"# Plot training and validation loss\n",
"plt.figure(figsize=(10,5))\n",
"plt.plot(history.history['loss'], label='Training Loss')\n",
"plt.plot(history.history['val_loss'], label='Validation Loss', linestyle='--')\n",
"plt.title(\"Learning Curve\")\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.legend()\n",
"plt.show()\n",
"\n",
"# Save the weights\n",
"m.save_weights(\"UNetW.weights.h5\")\n",
"\n",
"# Example prediction\n",
"y_pred = m.predict(X, batch_size=batch_size)\n",
"print(\"y_pred shape:\", y_pred.shape)\n",
"\n",
"# Visualize a sample\n",
"idx = 0\n",
"fig, ax = plt.subplots(1, 2, figsize=(8,4))\n",
"ax[0].imshow(X[idx].reshape(64,64), cmap='gray')\n",
"ax[0].set_title('Input Image')\n",
"ax[1].imshow(y_pred[idx].reshape(32,32), cmap='gray')\n",
"ax[1].set_title('Predicted Mask')\n",
"plt.show()"
]
},
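{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shape arithmetic in `create_model` (64 -> 54 after the 11x11 valid convolution, 54 -> 9 after 6x6 average pooling, 9*9*100 = 8100 features) can be confirmed directly from the built model (a minimal sketch):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Sketch: confirm the layer-by-layer shapes of the CNN\n",
"for layer in m.layers:\n",
"    print(layer.name, layer.output.shape)\n",
"# Expected: (None, 54, 54, 100) -> (None, 9, 9, 100) -> (None, 8100)\n",
"#           -> (None, 1024) -> (None, 32, 32, 1)"
]
},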
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AttentionUNET (Vishal)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Activation, BatchNormalization, Add, Multiply\n",
"from tensorflow.keras.models import Model\n",
"import os\n",
"import numpy as np\n",
"from tensorflow.keras.preprocessing.image import load_img, img_to_array\n",
"\n",
"def attention_block(x, g, inter_channel):\n",
" \"\"\"\n",
" Attention Block: Refines encoder features based on decoder signals.\n",
" x: Input tensor from the encoder (skip connection)\n",
" g: Gating signal from the decoder (upsampled tensor)\n",
" inter_channel: Number of intermediate channels (reduces computation)\n",
" \"\"\"\n",
" # 1x1 Convolution on input tensor\n",
" theta_x = Conv2D(inter_channel, kernel_size=(1, 1), strides=(1, 1), padding='same')(x)\n",
" # 1x1 Convolution on gating tensor\n",
" phi_g = Conv2D(inter_channel, kernel_size=(1, 1), strides=(1, 1), padding='same')(g)\n",
" \n",
" # Add the transformed inputs and apply ReLU\n",
" add_xg = Add()([theta_x, phi_g])\n",
" relu_xg = Activation('relu')(add_xg)\n",
" \n",
" # Another 1x1 Convolution to generate attention coefficients\n",
" psi = Conv2D(1, kernel_size=(1, 1), strides=(1, 1), padding='same')(relu_xg)\n",
" # Sigmoid activation to normalize attention weights\n",
" sigmoid_psi = Activation('sigmoid')(psi)\n",
" \n",
" # Multiply the input tensor with the attention weights\n",
" return Multiply()([x, sigmoid_psi])\n",
"\n",
"def conv_block(x, filters):\n",
" \"\"\"\n",
" Convolutional Block: Apply two 3x3 convolutions followed by BatchNorm and ReLU.\n",
" x: Input tensor\n",
" filters: Number of output filters for the convolutions\n",
" \"\"\"\n",
" x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)\n",
" x = BatchNormalization()(x)\n",
" x = Activation('relu')(x)\n",
" x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)\n",
" x = BatchNormalization()(x)\n",
" x = Activation('relu')(x)\n",
" return x\n",
"\n",
"def attention_unet(input_shape, num_classes):\n",
" \"\"\"\n",
" Attention U-Net model architecture.\n",
" input_shape: Shape of input images (H, W, C)\n",
" num_classes: Number of output segmentation classes\n",
" \"\"\"\n",
" # Input layer for the images\n",
" inputs = Input(input_shape)\n",
" \n",
" # Encoder (Downsampling path)\n",
" c1 = conv_block(inputs, 64) # First Conv Block\n",
" p1 = MaxPooling2D((2, 2))(c1) # Downsample by 2\n",
" \n",
" c2 = conv_block(p1, 128) # Second Conv Block\n",
" p2 = MaxPooling2D((2, 2))(c2) # Downsample by 2\n",
" \n",
" c3 = conv_block(p2, 256) # Third Conv Block\n",
" p3 = MaxPooling2D((2, 2))(c3) # Downsample by 2\n",
" \n",
" c4 = conv_block(p3, 512) # Fourth Conv Block\n",
" p4 = MaxPooling2D((2, 2))(c4) # Downsample by 2\n",
" \n",
" # Bottleneck (lowest level of the U-Net)\n",
" c5 = conv_block(p4, 1024)\n",
" \n",
" # Decoder (Upsampling path)\n",
" up6 = UpSampling2D((2, 2))(c5) # Upsample\n",
" att6 = attention_block(c4, up6, 512) # Attention Block\n",
" merge6 = concatenate([up6, att6], axis=-1) # Concatenate features\n",
" c6 = conv_block(merge6, 512) # Conv Block after concatenation\n",
" \n",
" up7 = UpSampling2D((2, 2))(c6)\n",
" att7 = attention_block(c3, up7, 256)\n",
" merge7 = concatenate([up7, att7], axis=-1)\n",
" c7 = conv_block(merge7, 256)\n",
" \n",
" up8 = UpSampling2D((2, 2))(c7)\n",
" att8 = attention_block(c2, up8, 128)\n",
" merge8 = concatenate([up8, att8], axis=-1)\n",
" c8 = conv_block(merge8, 128)\n",
" \n",
" up9 = UpSampling2D((2, 2))(c8)\n",
" att9 = attention_block(c1, up9, 64)\n",
" merge9 = concatenate([up9, att9], axis=-1)\n",
" c9 = conv_block(merge9, 64)\n",
" \n",
" # Output layer for segmentation\n",
" outputs = Conv2D(num_classes, (1, 1), activation='softmax' if num_classes > 1 else 'sigmoid')(c9)\n",
" \n",
" # Define the model\n",
" model = Model(inputs=inputs, outputs=outputs)\n",
" return model\n",
"\n",
"# Function to load and preprocess images and masks\n",
"def load_data(image_dir, mask_dir, image_size):\n",
" \"\"\"\n",
" Load and preprocess images and masks for training.\n",
" image_dir: Path to the directory containing input images\n",
" mask_dir: Path to the directory containing segmentation masks\n",
" image_size: Tuple specifying the size (height, width) to resize the images and masks\n",
" \"\"\"\n",
" images = []\n",
" masks = []\n",
" image_files = sorted(os.listdir(image_dir))\n",
" mask_files = sorted(os.listdir(mask_dir))\n",
" \n",
" for img_file, mask_file in zip(image_files, mask_files):\n",
" try:\n",
" # Load and preprocess images\n",
" img_path = os.path.join(image_dir, img_file)\n",
" mask_path = os.path.join(mask_dir, mask_file)\n",
" \n",
" img = load_img(img_path, target_size=image_size) # Resize image\n",
" mask = load_img(mask_path, target_size=image_size, color_mode='grayscale') # Resize mask\n",
" \n",
" # Convert to numpy arrays and normalize\n",
" img = img_to_array(img) / 255.0\n",
" mask = img_to_array(mask) / 255.0\n",
" mask = np.round(mask) # Ensure masks are binary\n",
" \n",
" images.append(img)\n",
" masks.append(mask)\n",
" except Exception as e:\n",
" print(f\"Error loading {img_file} or {mask_file}: {e}. Skipping...\")\n",
" \n",
" return np.array(images), np.array(masks)\n",
"\n",
"# Example usage\n",
"if __name__ == \"__main__\":\n",
" # Load data\n",
" image_dir = \"./images/\" # Replace with your image directory\n",
" mask_dir = \"./masks/\" # Replace with your mask directory\n",
" image_size = (128, 128) # Resize all images to 128x128\n",
" images, masks = load_data(image_dir, mask_dir, image_size)\n",
" \n",
" # Define the model\n",
" model = attention_unet(input_shape=(128, 128, 3), num_classes=1)\n",
" \n",
" # Compile the model\n",
" model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n",
" \n",
" # Train the model\n",
" model.fit(images, masks, batch_size=8, epochs=20, validation_split=0.1)"
]
},
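{
"cell_type": "markdown",
"metadata": {},
"source": [
"The attention block gates each skip connection: the coefficients collapse to a single channel and reweight the skip features per pixel, so the output keeps the skip tensor's shape. A minimal sketch on dummy tensors (the underscore names are illustrative; the shapes match the first decoder stage, where `c4` is 16x16 with 512 channels and `up6` is 16x16 with 1024):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Sketch: shape behaviour of attention_block on dummy tensors\n",
"_x = Input((16, 16, 512))    # stands in for the skip connection c4\n",
"_g = Input((16, 16, 1024))   # stands in for the upsampled gating signal up6\n",
"_att = attention_block(_x, _g, 512)\n",
"print(_att.shape)            # (None, 16, 16, 512): same shape as the skip, reweighted per pixel"
]
}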
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}