In [3]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Activation, BatchNormalization, Add, Multiply
from tensorflow.keras.models import Model
import os
import numpy as np
from tensorflow.keras.preprocessing.image import load_img, img_to_array

def attention_block(x, g, inter_channel):
    """
    Attention Block: Refines encoder features based on decoder signals.
    x: Input tensor from the encoder (skip connection)
    g: Gating signal from the decoder (upsampled tensor)
    inter_channel: Number of intermediate channels (reduces computation)
    """
    # 1x1 Convolution on input tensor
    theta_x = Conv2D(inter_channel, kernel_size=(1, 1), strides=(1, 1), padding='same')(x)
    # 1x1 Convolution on gating tensor
    phi_g = Conv2D(inter_channel, kernel_size=(1, 1), strides=(1, 1), padding='same')(g)
    
    # Add the transformed inputs and apply ReLU
    add_xg = Add()([theta_x, phi_g])
    relu_xg = Activation('relu')(add_xg)
    
    # Another 1x1 Convolution to generate attention coefficients
    psi = Conv2D(1, kernel_size=(1, 1), strides=(1, 1), padding='same')(relu_xg)
    # Sigmoid activation to normalize attention weights
    sigmoid_psi = Activation('sigmoid')(psi)
    
    # Multiply the input tensor with the attention weights
    return Multiply()([x, sigmoid_psi])

def conv_block(x, filters):
    """
    Convolutional Block: Apply two 3x3 convolutions followed by BatchNorm and ReLU.
    x: Input tensor
    filters: Number of output filters for the convolutions
    """
    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return x

def attention_unet(input_shape, num_classes):
    """
    Attention U-Net model architecture.
    input_shape: Shape of input images (H, W, C)
    num_classes: Number of output segmentation classes
    """
    # Input layer for the images
    inputs = Input(input_shape)
    
    # Encoder (Downsampling path)
    c1 = conv_block(inputs, 64)              # First Conv Block
    p1 = MaxPooling2D((2, 2))(c1)            # Downsample by 2
    
    c2 = conv_block(p1, 128)                 # Second Conv Block
    p2 = MaxPooling2D((2, 2))(c2)            # Downsample by 2
    
    c3 = conv_block(p2, 256)                 # Third Conv Block
    p3 = MaxPooling2D((2, 2))(c3)            # Downsample by 2
    
    c4 = conv_block(p3, 512)                 # Fourth Conv Block
    p4 = MaxPooling2D((2, 2))(c4)            # Downsample by 2
    
    # Bottleneck (lowest level of the U-Net)
    c5 = conv_block(p4, 1024)
    
    # Decoder (Upsampling path)
    up6 = UpSampling2D((2, 2))(c5)           # Upsample
    att6 = attention_block(c4, up6, 512)     # Attention Block
    merge6 = concatenate([up6, att6], axis=-1)  # Concatenate features
    c6 = conv_block(merge6, 512)             # Conv Block after concatenation
    
    up7 = UpSampling2D((2, 2))(c6)
    att7 = attention_block(c3, up7, 256)
    merge7 = concatenate([up7, att7], axis=-1)
    c7 = conv_block(merge7, 256)
    
    up8 = UpSampling2D((2, 2))(c7)
    att8 = attention_block(c2, up8, 128)
    merge8 = concatenate([up8, att8], axis=-1)
    c8 = conv_block(merge8, 128)
    
    up9 = UpSampling2D((2, 2))(c8)
    att9 = attention_block(c1, up9, 64)
    merge9 = concatenate([up9, att9], axis=-1)
    c9 = conv_block(merge9, 64)
    
    # Output layer for segmentation
    outputs = Conv2D(num_classes, (1, 1), activation='softmax' if num_classes > 1 else 'sigmoid')(c9)
    
    # Define the model
    model = Model(inputs=inputs, outputs=outputs)
    return model

# Function to load and preprocess images and masks
def load_data(image_dir, mask_dir, image_size):
    """
    Load and preprocess images and masks for training.
    image_dir: Path to the directory containing input images
    mask_dir: Path to the directory containing segmentation masks
    image_size: Tuple specifying the size (height, width) to resize the images and masks
    """
    images = []
    masks = []
    image_files = sorted(os.listdir(image_dir))
    mask_files = sorted(os.listdir(mask_dir))
    
    for img_file, mask_file in zip(image_files, mask_files):
        try:
            # Load and preprocess images
            img_path = os.path.join(image_dir, img_file)
            mask_path = os.path.join(mask_dir, mask_file)
            
            img = load_img(img_path, target_size=image_size)  # Resize image
            mask = load_img(mask_path, target_size=image_size, color_mode='grayscale')  # Resize mask
            
            # Convert to numpy arrays and normalize
            img = img_to_array(img) / 255.0
            mask = img_to_array(mask) / 255.0
            mask = np.round(mask)  # Ensure masks are binary
            
            images.append(img)
            masks.append(mask)
        except Exception as e:
            print(f"Error loading {img_file} or {mask_file}: {e}. Skipping...")
    
    return np.array(images), np.array(masks)

# Example usage
if __name__ == "__main__":
    # Load data
    image_dir = "./images/"  # Replace with your image directory
    mask_dir = "./masks/"    # Replace with your mask directory
    image_size = (128, 128)       # Resize all images to 128x128
    images, masks = load_data(image_dir, mask_dir, image_size)
    
    # Define the model
    model = attention_unet(input_shape=(128, 128, 3), num_classes=1)
    
    # Compile the model
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    # Train the model
    model.fit(images, masks, batch_size=8, epochs=20, validation_split=0.1)

Error loading .DS_Store or 0655[0]_47.png: cannot identify image file <_io.BytesIO object at 0x35adee660>. Skipping...
Epoch 1/20
[1m193/193[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m384s[0m 2s/step - accuracy: 0.9061 - loss: 0.2485 - val_accuracy: 0.8808 - val_loss: 0.3486
Epoch 2/20
[1m193/193[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m384s[0m 2s/step - accuracy: 0.9415 - loss: 0.1394 - val_accuracy: 0.8412 - val_loss: 0.4048
Epoch 3/20
[1m193/193[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m378s[0m 2s/step - accuracy: 0.9457 - loss: 0.1280 - val_accuracy: 0.8718 - val_loss: 0.4388
Epoch 4/20
[1m193/193[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m385s[0m 2s/step - accuracy: 0.9491 - loss: 0.1193 - val_accuracy: 0.8620 - val_loss: 0.4341
Epoch 5/20
[1m193/193[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m378s[0m 2s/step - accuracy: 0.9492 - loss: 0.1185 - val_accuracy: 0.8636 - val_loss: 0.5675
Epoch 6/20
[1m193/193[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m 