Diff of /AttentonUnet.ipynb [000000] .. [c1eed3]

a b/AttentonUnet.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error loading .DS_Store or 0655[0]_47.png: cannot identify image file <_io.BytesIO object at 0x35adee660>. Skipping...\n",
      "Epoch 1/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m384s\u001b[0m 2s/step - accuracy: 0.9061 - loss: 0.2485 - val_accuracy: 0.8808 - val_loss: 0.3486\n",
      "Epoch 2/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m384s\u001b[0m 2s/step - accuracy: 0.9415 - loss: 0.1394 - val_accuracy: 0.8412 - val_loss: 0.4048\n",
      "Epoch 3/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m378s\u001b[0m 2s/step - accuracy: 0.9457 - loss: 0.1280 - val_accuracy: 0.8718 - val_loss: 0.4388\n",
      "Epoch 4/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m385s\u001b[0m 2s/step - accuracy: 0.9491 - loss: 0.1193 - val_accuracy: 0.8620 - val_loss: 0.4341\n",
      "Epoch 5/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m378s\u001b[0m 2s/step - accuracy: 0.9492 - loss: 0.1185 - val_accuracy: 0.8636 - val_loss: 0.5675\n",
      "Epoch 6/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m384s\u001b[0m 2s/step - accuracy: 0.9515 - loss: 0.1134 - val_accuracy: 0.8706 - val_loss: 0.5460\n",
      "Epoch 7/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m384s\u001b[0m 2s/step - accuracy: 0.9568 - loss: 0.0998 - val_accuracy: 0.8562 - val_loss: 0.6479\n",
      "Epoch 8/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m382s\u001b[0m 2s/step - accuracy: 0.9572 - loss: 0.0983 - val_accuracy: 0.8637 - val_loss: 1.0583\n",
      "Epoch 9/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m391s\u001b[0m 2s/step - accuracy: 0.9601 - loss: 0.0928 - val_accuracy: 0.8689 - val_loss: 0.4872\n",
      "Epoch 10/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m385s\u001b[0m 2s/step - accuracy: 0.9616 - loss: 0.0885 - val_accuracy: 0.8676 - val_loss: 0.6407\n",
      "Epoch 11/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m373s\u001b[0m 2s/step - accuracy: 0.9648 - loss: 0.0807 - val_accuracy: 0.8683 - val_loss: 0.6889\n",
      "Epoch 12/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m377s\u001b[0m 2s/step - accuracy: 0.9663 - loss: 0.0786 - val_accuracy: 0.8550 - val_loss: 0.7435\n",
      "Epoch 13/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m374s\u001b[0m 2s/step - accuracy: 0.9703 - loss: 0.0703 - val_accuracy: 0.8677 - val_loss: 0.6834\n",
      "Epoch 14/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m373s\u001b[0m 2s/step - accuracy: 0.9712 - loss: 0.0665 - val_accuracy: 0.8694 - val_loss: 0.5149\n",
      "Epoch 15/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m379s\u001b[0m 2s/step - accuracy: 0.9716 - loss: 0.0672 - val_accuracy: 0.8633 - val_loss: 0.7259\n",
      "Epoch 16/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m395s\u001b[0m 2s/step - accuracy: 0.9748 - loss: 0.0594 - val_accuracy: 0.8736 - val_loss: 0.6896\n",
      "Epoch 17/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m380s\u001b[0m 2s/step - accuracy: 0.9767 - loss: 0.0545 - val_accuracy: 0.8695 - val_loss: 0.7535\n",
      "Epoch 18/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m389s\u001b[0m 2s/step - accuracy: 0.9773 - loss: 0.0532 - val_accuracy: 0.8664 - val_loss: 0.8831\n",
      "Epoch 19/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m376s\u001b[0m 2s/step - accuracy: 0.9781 - loss: 0.0512 - val_accuracy: 0.8720 - val_loss: 0.7170\n",
      "Epoch 20/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m384s\u001b[0m 2s/step - accuracy: 0.9790 - loss: 0.0487 - val_accuracy: 0.8707 - val_loss: 0.6628\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Activation, BatchNormalization, Add, Multiply\n",
    "from tensorflow.keras.models import Model\n",
    "import os\n",
    "import numpy as np\n",
    "from tensorflow.keras.preprocessing.image import load_img, img_to_array\n",
    "\n",
    "def attention_block(x, g, inter_channel):\n",
    "    \"\"\"\n",
    "    Attention Block: Refines encoder features based on decoder signals.\n",
    "    x: Input tensor from the encoder (skip connection)\n",
    "    g: Gating signal from the decoder (upsampled tensor)\n",
    "    inter_channel: Number of intermediate channels (reduces computation)\n",
    "    \"\"\"\n",
    "    # 1x1 Convolution on input tensor\n",
    "    theta_x = Conv2D(inter_channel, kernel_size=(1, 1), strides=(1, 1), padding='same')(x)\n",
    "    # 1x1 Convolution on gating tensor\n",
    "    phi_g = Conv2D(inter_channel, kernel_size=(1, 1), strides=(1, 1), padding='same')(g)\n",
    "    \n",
    "    # Add the transformed inputs and apply ReLU\n",
    "    add_xg = Add()([theta_x, phi_g])\n",
    "    relu_xg = Activation('relu')(add_xg)\n",
    "    \n",
    "    # Another 1x1 Convolution to generate attention coefficients\n",
    "    psi = Conv2D(1, kernel_size=(1, 1), strides=(1, 1), padding='same')(relu_xg)\n",
    "    # Sigmoid activation to normalize attention weights\n",
    "    sigmoid_psi = Activation('sigmoid')(psi)\n",
    "    \n",
    "    # Multiply the input tensor with the attention weights\n",
    "    return Multiply()([x, sigmoid_psi])\n",
    "\n",
    "def conv_block(x, filters):\n",
    "    \"\"\"\n",
    "    Convolutional Block: Apply two 3x3 convolutions followed by BatchNorm and ReLU.\n",
    "    x: Input tensor\n",
    "    filters: Number of output filters for the convolutions\n",
    "    \"\"\"\n",
    "    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = Activation('relu')(x)\n",
    "    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = Activation('relu')(x)\n",
    "    return x\n",
    "\n",
    "def attention_unet(input_shape, num_classes):\n",
    "    \"\"\"\n",
    "    Attention U-Net model architecture.\n",
    "    input_shape: Shape of input images (H, W, C)\n",
    "    num_classes: Number of output segmentation classes\n",
    "    \"\"\"\n",
    "    # Input layer for the images\n",
    "    inputs = Input(input_shape)\n",
    "    \n",
    "    # Encoder (Downsampling path)\n",
    "    c1 = conv_block(inputs, 64)              # First Conv Block\n",
    "    p1 = MaxPooling2D((2, 2))(c1)            # Downsample by 2\n",
    "    \n",
    "    c2 = conv_block(p1, 128)                 # Second Conv Block\n",
    "    p2 = MaxPooling2D((2, 2))(c2)            # Downsample by 2\n",
    "    \n",
    "    c3 = conv_block(p2, 256)                 # Third Conv Block\n",
    "    p3 = MaxPooling2D((2, 2))(c3)            # Downsample by 2\n",
    "    \n",
    "    c4 = conv_block(p3, 512)                 # Fourth Conv Block\n",
    "    p4 = MaxPooling2D((2, 2))(c4)            # Downsample by 2\n",
    "    \n",
    "    # Bottleneck (lowest level of the U-Net)\n",
    "    c5 = conv_block(p4, 1024)\n",
    "    \n",
    "    # Decoder (Upsampling path)\n",
    "    up6 = UpSampling2D((2, 2))(c5)           # Upsample\n",
    "    att6 = attention_block(c4, up6, 512)     # Attention Block\n",
    "    merge6 = concatenate([up6, att6], axis=-1)  # Concatenate features\n",
    "    c6 = conv_block(merge6, 512)             # Conv Block after concatenation\n",
    "    \n",
    "    up7 = UpSampling2D((2, 2))(c6)\n",
    "    att7 = attention_block(c3, up7, 256)\n",
    "    merge7 = concatenate([up7, att7], axis=-1)\n",
    "    c7 = conv_block(merge7, 256)\n",
    "    \n",
    "    up8 = UpSampling2D((2, 2))(c7)\n",
    "    att8 = attention_block(c2, up8, 128)\n",
    "    merge8 = concatenate([up8, att8], axis=-1)\n",
    "    c8 = conv_block(merge8, 128)\n",
    "    \n",
    "    up9 = UpSampling2D((2, 2))(c8)\n",
    "    att9 = attention_block(c1, up9, 64)\n",
    "    merge9 = concatenate([up9, att9], axis=-1)\n",
    "    c9 = conv_block(merge9, 64)\n",
    "    \n",
    "    # Output layer for segmentation\n",
    "    outputs = Conv2D(num_classes, (1, 1), activation='softmax' if num_classes > 1 else 'sigmoid')(c9)\n",
    "    \n",
    "    # Define the model\n",
    "    model = Model(inputs=inputs, outputs=outputs)\n",
    "    return model\n",
    "\n",
    "# Function to load and preprocess images and masks\n",
    "def load_data(image_dir, mask_dir, image_size):\n",
    "    \"\"\"\n",
    "    Load and preprocess images and masks for training.\n",
    "    image_dir: Path to the directory containing input images\n",
    "    mask_dir: Path to the directory containing segmentation masks\n",
    "    image_size: Tuple specifying the size (height, width) to resize the images and masks\n",
    "    \"\"\"\n",
    "    images = []\n",
    "    masks = []\n",
    "    image_files = sorted(os.listdir(image_dir))\n",
166
    "    mask_files = sorted(os.listdir(mask_dir))\n",
    "    \n",
    "    for img_file, mask_file in zip(image_files, mask_files):\n",
    "        try:\n",
    "            # Load and preprocess images\n",
    "            img_path = os.path.join(image_dir, img_file)\n",
    "            mask_path = os.path.join(mask_dir, mask_file)\n",
    "            \n",
    "            img = load_img(img_path, target_size=image_size)  # Resize image\n",
    "            mask = load_img(mask_path, target_size=image_size, color_mode='grayscale')  # Resize mask\n",
    "            \n",
    "            # Convert to numpy arrays and normalize\n",
    "            img = img_to_array(img) / 255.0\n",
    "            mask = img_to_array(mask) / 255.0\n",
    "            mask = np.round(mask)  # Ensure masks are binary\n",
    "            \n",
    "            images.append(img)\n",
    "            masks.append(mask)\n",
    "        except Exception as e:\n",
    "            print(f\"Error loading {img_file} or {mask_file}: {e}. Skipping...\")\n",
    "    \n",
    "    return np.array(images), np.array(masks)\n",
    "\n",
    "# Example usage\n",
    "if __name__ == \"__main__\":\n",
    "    # Load data\n",
    "    image_dir = \"./images/\"  # Replace with your image directory\n",
    "    mask_dir = \"./masks/\"    # Replace with your mask directory\n",
    "    image_size = (128, 128)       # Resize all images to 128x128\n",
    "    images, masks = load_data(image_dir, mask_dir, image_size)\n",
    "    \n",
    "    # Define the model\n",
    "    model = attention_unet(input_shape=(128, 128, 3), num_classes=1)\n",
    "    \n",
    "    # Compile the model\n",
    "    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n",
    "    \n",
    "    # Train the model\n",
    "    model.fit(images, masks, batch_size=8, epochs=20, validation_split=0.1)"
   ]
  },
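  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Illustrative inference sketch (editorial addition, not from the original run).** Assuming `model`, `images`, and `np` from the cell above are still in memory and matplotlib is available, the cell below runs the trained network on one image, thresholds the sigmoid output at 0.5, and plots the input next to the predicted mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch (editorial addition): predict a mask for one image and visualize it.\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "sample = images[:1]                                   # one image, shape (1, 128, 128, 3)\n",
    "pred = model.predict(sample)                          # sigmoid probabilities, shape (1, 128, 128, 1)\n",
    "pred_mask = (pred[0, ..., 0] > 0.5).astype(np.uint8)  # binarize at 0.5\n",
    "\n",
    "plt.figure(figsize=(8, 4))\n",
    "plt.subplot(1, 2, 1); plt.imshow(sample[0]); plt.title('Input'); plt.axis('off')\n",
    "plt.subplot(1, 2, 2); plt.imshow(pred_mask, cmap='gray'); plt.title('Predicted mask'); plt.axis('off')\n",
    "plt.show()"
   ]
  },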
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}