b/AttentonUnet.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error loading .DS_Store or 0655[0]_47.png: cannot identify image file <_io.BytesIO object at 0x35adee660>. Skipping...\n",
      "Epoch 1/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m384s\u001b[0m 2s/step - accuracy: 0.9061 - loss: 0.2485 - val_accuracy: 0.8808 - val_loss: 0.3486\n",
      "Epoch 2/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m384s\u001b[0m 2s/step - accuracy: 0.9415 - loss: 0.1394 - val_accuracy: 0.8412 - val_loss: 0.4048\n",
      "Epoch 3/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m378s\u001b[0m 2s/step - accuracy: 0.9457 - loss: 0.1280 - val_accuracy: 0.8718 - val_loss: 0.4388\n",
      "Epoch 4/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m385s\u001b[0m 2s/step - accuracy: 0.9491 - loss: 0.1193 - val_accuracy: 0.8620 - val_loss: 0.4341\n",
      "Epoch 5/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m378s\u001b[0m 2s/step - accuracy: 0.9492 - loss: 0.1185 - val_accuracy: 0.8636 - val_loss: 0.5675\n",
      "Epoch 6/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m384s\u001b[0m 2s/step - accuracy: 0.9515 - loss: 0.1134 - val_accuracy: 0.8706 - val_loss: 0.5460\n",
      "Epoch 7/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m384s\u001b[0m 2s/step - accuracy: 0.9568 - loss: 0.0998 - val_accuracy: 0.8562 - val_loss: 0.6479\n",
      "Epoch 8/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m382s\u001b[0m 2s/step - accuracy: 0.9572 - loss: 0.0983 - val_accuracy: 0.8637 - val_loss: 1.0583\n",
      "Epoch 9/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m391s\u001b[0m 2s/step - accuracy: 0.9601 - loss: 0.0928 - val_accuracy: 0.8689 - val_loss: 0.4872\n",
      "Epoch 10/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m385s\u001b[0m 2s/step - accuracy: 0.9616 - loss: 0.0885 - val_accuracy: 0.8676 - val_loss: 0.6407\n",
      "Epoch 11/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m373s\u001b[0m 2s/step - accuracy: 0.9648 - loss: 0.0807 - val_accuracy: 0.8683 - val_loss: 0.6889\n",
      "Epoch 12/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m377s\u001b[0m 2s/step - accuracy: 0.9663 - loss: 0.0786 - val_accuracy: 0.8550 - val_loss: 0.7435\n",
      "Epoch 13/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m374s\u001b[0m 2s/step - accuracy: 0.9703 - loss: 0.0703 - val_accuracy: 0.8677 - val_loss: 0.6834\n",
      "Epoch 14/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m373s\u001b[0m 2s/step - accuracy: 0.9712 - loss: 0.0665 - val_accuracy: 0.8694 - val_loss: 0.5149\n",
      "Epoch 15/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m379s\u001b[0m 2s/step - accuracy: 0.9716 - loss: 0.0672 - val_accuracy: 0.8633 - val_loss: 0.7259\n",
      "Epoch 16/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m395s\u001b[0m 2s/step - accuracy: 0.9748 - loss: 0.0594 - val_accuracy: 0.8736 - val_loss: 0.6896\n",
      "Epoch 17/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m380s\u001b[0m 2s/step - accuracy: 0.9767 - loss: 0.0545 - val_accuracy: 0.8695 - val_loss: 0.7535\n",
      "Epoch 18/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m389s\u001b[0m 2s/step - accuracy: 0.9773 - loss: 0.0532 - val_accuracy: 0.8664 - val_loss: 0.8831\n",
      "Epoch 19/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m376s\u001b[0m 2s/step - accuracy: 0.9781 - loss: 0.0512 - val_accuracy: 0.8720 - val_loss: 0.7170\n",
      "Epoch 20/20\n",
      "\u001b[1m193/193\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m384s\u001b[0m 2s/step - accuracy: 0.9790 - loss: 0.0487 - val_accuracy: 0.8707 - val_loss: 0.6628\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Activation, BatchNormalization, Add, Multiply\n",
    "from tensorflow.keras.models import Model\n",
    "import os\n",
    "import numpy as np\n",
    "from tensorflow.keras.preprocessing.image import load_img, img_to_array\n",
    "\n",
    "def attention_block(x, g, inter_channel):\n",
    "    \"\"\"\n",
    "    Attention Block: Refines encoder features based on decoder signals.\n",
    "    x: Input tensor from the encoder (skip connection)\n",
    "    g: Gating signal from the decoder (upsampled tensor)\n",
    "    inter_channel: Number of intermediate channels (reduces computation)\n",
    "    \"\"\"\n",
    "    # 1x1 Convolution on input tensor\n",
    "    theta_x = Conv2D(inter_channel, kernel_size=(1, 1), strides=(1, 1), padding='same')(x)\n",
    "    # 1x1 Convolution on gating tensor\n",
    "    phi_g = Conv2D(inter_channel, kernel_size=(1, 1), strides=(1, 1), padding='same')(g)\n",
    "\n",
    "    # Add the transformed inputs and apply ReLU\n",
    "    add_xg = Add()([theta_x, phi_g])\n",
    "    relu_xg = Activation('relu')(add_xg)\n",
    "\n",
    "    # Another 1x1 Convolution to generate attention coefficients\n",
    "    psi = Conv2D(1, kernel_size=(1, 1), strides=(1, 1), padding='same')(relu_xg)\n",
    "    # Sigmoid activation to normalize attention weights\n",
    "    sigmoid_psi = Activation('sigmoid')(psi)\n",
    "\n",
    "    # Multiply the input tensor with the attention weights\n",
    "    return Multiply()([x, sigmoid_psi])\n",
    "\n",
    "def conv_block(x, filters):\n",
    "    \"\"\"\n",
    "    Convolutional Block: Apply two 3x3 convolutions, each followed by BatchNorm and ReLU.\n",
    "    x: Input tensor\n",
    "    filters: Number of output filters for the convolutions\n",
    "    \"\"\"\n",
    "    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = Activation('relu')(x)\n",
    "    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = Activation('relu')(x)\n",
    "    return x\n",
    "\n",
    "def attention_unet(input_shape, num_classes):\n",
    "    \"\"\"\n",
    "    Attention U-Net model architecture.\n",
    "    input_shape: Shape of input images (H, W, C)\n",
    "    num_classes: Number of output segmentation classes\n",
    "    \"\"\"\n",
    "    # Input layer for the images\n",
    "    inputs = Input(input_shape)\n",
    "\n",
    "    # Encoder (Downsampling path)\n",
    "    c1 = conv_block(inputs, 64)    # First Conv Block\n",
    "    p1 = MaxPooling2D((2, 2))(c1)  # Downsample by 2\n",
    "\n",
    "    c2 = conv_block(p1, 128)       # Second Conv Block\n",
    "    p2 = MaxPooling2D((2, 2))(c2)  # Downsample by 2\n",
    "\n",
    "    c3 = conv_block(p2, 256)       # Third Conv Block\n",
    "    p3 = MaxPooling2D((2, 2))(c3)  # Downsample by 2\n",
    "\n",
    "    c4 = conv_block(p3, 512)       # Fourth Conv Block\n",
    "    p4 = MaxPooling2D((2, 2))(c4)  # Downsample by 2\n",
    "\n",
    "    # Bottleneck (lowest level of the U-Net)\n",
    "    c5 = conv_block(p4, 1024)\n",
    "\n",
    "    # Decoder (Upsampling path)\n",
    "    up6 = UpSampling2D((2, 2))(c5)              # Upsample\n",
    "    att6 = attention_block(c4, up6, 512)        # Attention Block\n",
    "    merge6 = concatenate([up6, att6], axis=-1)  # Concatenate features\n",
    "    c6 = conv_block(merge6, 512)                # Conv Block after concatenation\n",
    "\n",
    "    up7 = UpSampling2D((2, 2))(c6)\n",
    "    att7 = attention_block(c3, up7, 256)\n",
    "    merge7 = concatenate([up7, att7], axis=-1)\n",
    "    c7 = conv_block(merge7, 256)\n",
    "\n",
    "    up8 = UpSampling2D((2, 2))(c7)\n",
    "    att8 = attention_block(c2, up8, 128)\n",
    "    merge8 = concatenate([up8, att8], axis=-1)\n",
    "    c8 = conv_block(merge8, 128)\n",
    "\n",
    "    up9 = UpSampling2D((2, 2))(c8)\n",
    "    att9 = attention_block(c1, up9, 64)\n",
    "    merge9 = concatenate([up9, att9], axis=-1)\n",
    "    c9 = conv_block(merge9, 64)\n",
    "\n",
    "    # Output layer for segmentation\n",
    "    outputs = Conv2D(num_classes, (1, 1), activation='softmax' if num_classes > 1 else 'sigmoid')(c9)\n",
    "\n",
    "    # Define the model\n",
    "    model = Model(inputs=inputs, outputs=outputs)\n",
    "    return model\n",
    "\n",
    "# Function to load and preprocess images and masks\n",
    "def load_data(image_dir, mask_dir, image_size):\n",
    "    \"\"\"\n",
    "    Load and preprocess images and masks for training.\n",
    "    image_dir: Path to the directory containing input images\n",
    "    mask_dir: Path to the directory containing segmentation masks\n",
    "    image_size: Tuple specifying the size (height, width) to resize the images and masks\n",
    "    \"\"\"\n",
    "    images = []\n",
    "    masks = []\n",
    "    # Ignore hidden files such as .DS_Store so images and masks stay correctly paired by zip()\n",
    "    image_files = sorted(f for f in os.listdir(image_dir) if not f.startswith('.'))\n",
    "    mask_files = sorted(f for f in os.listdir(mask_dir) if not f.startswith('.'))\n",
    "\n",
    "    for img_file, mask_file in zip(image_files, mask_files):\n",
    "        try:\n",
    "            # Load and preprocess images\n",
    "            img_path = os.path.join(image_dir, img_file)\n",
    "            mask_path = os.path.join(mask_dir, mask_file)\n",
    "\n",
    "            img = load_img(img_path, target_size=image_size)  # Resize image\n",
    "            mask = load_img(mask_path, target_size=image_size, color_mode='grayscale')  # Resize mask\n",
    "\n",
    "            # Convert to numpy arrays and normalize\n",
    "            img = img_to_array(img) / 255.0\n",
    "            mask = img_to_array(mask) / 255.0\n",
    "            mask = np.round(mask)  # Ensure masks are binary\n",
    "\n",
    "            images.append(img)\n",
    "            masks.append(mask)\n",
    "        except Exception as e:\n",
    "            print(f\"Error loading {img_file} or {mask_file}: {e}. Skipping...\")\n",
    "\n",
    "    return np.array(images), np.array(masks)\n",
    "\n",
    "# Example usage\n",
    "if __name__ == \"__main__\":\n",
    "    # Load data\n",
    "    image_dir = \"./images/\"   # Replace with your image directory\n",
    "    mask_dir = \"./masks/\"     # Replace with your mask directory\n",
    "    image_size = (128, 128)   # Resize all images to 128x128\n",
    "    images, masks = load_data(image_dir, mask_dir, image_size)\n",
    "\n",
    "    # Define the model\n",
    "    model = attention_unet(input_shape=(128, 128, 3), num_classes=1)\n",
    "\n",
    "    # Compile the model\n",
    "    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n",
    "\n",
    "    # Train the model\n",
    "    model.fit(images, masks, batch_size=8, epochs=20, validation_split=0.1)"
   ]
  },
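  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check (not part of the original training run): the cell below calls `attention_block` from the cell above on random tensors shaped like the first decoder stage for a 128x128 input, where the skip connection `c4` is 16x16 with 512 channels and the upsampled gating signal `up6` is 16x16 with 1024 channels. The single-channel sigmoid attention map is broadcast across the skip-connection channels, so the output keeps the skip connection's shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sanity check only; assumes attention_block from the cell above has been run.\n",
    "# Shapes mirror the first decoder stage of the model (c4: 16x16x512 skip, up6: 16x16x1024 gating).\n",
    "import tensorflow as tf\n",
    "\n",
    "x_dummy = tf.random.normal((1, 16, 16, 512))   # stands in for the encoder skip tensor c4\n",
    "g_dummy = tf.random.normal((1, 16, 16, 1024))  # stands in for the upsampled decoder tensor up6\n",
    "\n",
    "attended = attention_block(x_dummy, g_dummy, inter_channel=512)\n",
    "print(attended.shape)  # expected: (1, 16, 16, 512), same shape as the skip connection"
   ]
  },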
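  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rough inference sketch (also not part of the original run): it assumes `model` has been trained and `images` loaded by the cell above, predicts on a single preprocessed image, and thresholds the sigmoid probabilities at 0.5 to obtain a binary mask, mirroring the `np.round()` step applied to the training masks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough inference sketch; assumes `model` and `images` from the training cell above.\n",
    "import numpy as np\n",
    "\n",
    "sample = images[:1]                          # one preprocessed image, shape (1, 128, 128, 3)\n",
    "pred = model.predict(sample)                 # sigmoid probabilities, shape (1, 128, 128, 1)\n",
    "binary_mask = (pred > 0.5).astype(np.uint8)  # threshold to a binary segmentation mask\n",
    "print(pred.shape, binary_mask.sum(), \"foreground pixels\")"
   ]
  },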
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}