
--- a
+++ b/GI-Tract-Image-Segmentation.py
@@ -0,0 +1,379 @@
+""" Import statements and check for GPU """
+
+import os
+import re
+import glob
+import math
+import cv2
+import csv
+
+import pandas as pd
+import matplotlib.pyplot as plt
+import numpy as np
+
+from sklearn.model_selection import train_test_split
+from transunet import TransUNet
+
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.preprocessing.image import ImageDataGenerator
+from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
+
+from tensorflow import keras
+from tensorflow.keras import layers
+
+# List available GPUs
+gpus = tf.config.list_physical_devices('GPU')
+print("GPUs: ", gpus)
+
+if gpus:
+    print("TensorFlow is using the GPU.")
+else:
+    print("TensorFlow is not using the GPU.")
+
+
+
+
+
+""" Function Definitions """
+
+def rle_to_binary(rle, shape):
+    """
+    Decodes run length encoded masks into a binary image
+
+    Parameters:
+        rle (list): list containing the starts and lengths that make up each RLE mask
+        shape (tuple): the original shape of the associated image
+    """
+
+    # Initialize a flat mask with zeros
+    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
+    
+    if rle == '' or rle == '0':  # Handle empty RLE
+        return mask.reshape(shape, order='C')
+
+    # Decode RLE into mask
+    rle_numbers = list(map(int, rle.split()))
+    for i in range(0, len(rle_numbers), 2):
+        start = rle_numbers[i] - 1  # Convert to zero-indexed
+        length = rle_numbers[i + 1]
+        mask[start:start + length] = 1
+
+    # Reshape flat mask into 2D
+    return mask.reshape(shape, order='C')
+
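+# Example (illustrative): rle_to_binary("1 3 10 2", (4, 4)) marks flattened
+# (row-major) pixels 0-2 and 9-10 of a 4x4 image, i.e.
+#   [[1, 1, 1, 0],
+#    [0, 0, 0, 0],
+#    [0, 1, 1, 0],
+#    [0, 0, 0, 0]]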
+
+
+def custom_generator(gdf, dir, batch_size, target_size=(224, 224), test_mode=False):
+    """
+    Custom data generator that dynamically aligns images and masks using RLE decoding.
+    
+    Parameters:
+        gdf (GroupBy): Grouped dataframe containing image IDs and RLEs.
+        dir (str): Root directory of the dataset.
+        batch_size (int): Number of samples per batch.
+        target_size (tuple): Target size for resizing (default=(224, 224)).
+        test_mode (bool): If True, yields one (image, mask, [id]) sample at a time.
+    """
+
+    ids = list(gdf.groups.keys())
+    dir2 = 'train'
+
+    while True:
+        sample_ids = np.random.choice(ids, size=batch_size, replace=False)
+        images, masks = [], []
+
+        for id_num in sample_ids:
+            # Get the dataframe rows for the current image
+            img_rows = gdf.get_group(id_num)
+            rle_list = img_rows['segmentation'].tolist()
+            
+            # Construct the file path for the image
+            sections = id_num.split('_')
+            case = sections[0]
+            day = sections[0] + '_' + sections[1]
+            slice_id = sections[2] + '_' + sections[3]
+            
+            pattern = os.path.join(dir, dir2, case, day, "scans", f"{slice_id}*.png")
+            filelist = glob.glob(pattern)
+            
+            if filelist:
+                file = filelist[0]
+                image = cv2.imread(file, cv2.IMREAD_COLOR)
+                if image is None:
+                    print(f"Could not read image: {file}")
+                    continue  # Skip if the image can't be loaded
+                
+                # Original shape of the image
+                original_shape = image.shape[:2]
+
+                # Resize the image
+                resized_image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
+
+                # Decode and resize the masks
+                mask = np.zeros((target_size[0], target_size[1], len(rle_list)), dtype=np.uint8)
+                for i, rle in enumerate(rle_list):
+                    if rle != '0':  # Check if the RLE is valid
+                        decoded_mask = rle_to_binary(rle, original_shape)
+                        resized_mask = cv2.resize(decoded_mask, target_size, interpolation=cv2.INTER_NEAREST)
+                        mask[:, :, i] = resized_mask
+
+                if test_mode:
+                    # Return individual samples in test mode, along with the image id
+                    yield resized_image[np.newaxis], mask[np.newaxis], [id_num]
+                else:
+                    images.append(resized_image)
+                    masks.append(mask)
+
+        if not test_mode:
+            x = np.array(images)
+            y = np.array(masks)
+            # model.fit expects (inputs, targets) pairs from the generator
+            yield x, y
+
+       
+
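+# Usage sketch (illustrative): one draw from the training generator yields
+#   x, y = next(train_generator)
+#   x.shape -> (batch_size, 224, 224, 3)          resized images
+#   y.shape -> (batch_size, 224, 224, n_classes)  one binary channel per class RLE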
+
+
+""" Loss function: dice loss ignores negative class thus negating class imbalance issues """     
+
+def dice_coef(y_true, y_pred, smooth=1e-6):
+    # Ensure consistent data types
+    y_true = tf.cast(y_true, tf.float32)
+    y_pred = tf.cast(y_pred, tf.float32)
+
+    y_true_f = tf.keras.backend.flatten(y_true)
+    y_pred_f = tf.keras.backend.flatten(y_pred)
+    intersection = tf.reduce_sum(y_true_f * y_pred_f)
+    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
+
+def dice_loss(y_true, y_pred):
+    y_true = tf.cast(y_true, tf.float32)
+    y_pred = tf.cast(y_pred, tf.float32)
+    return 1 - dice_coef(y_true, y_pred)
+
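+# Quick sanity check (illustrative): identical masks give a loss near 0,
+# fully disjoint masks give a loss near 1.
+#   y = tf.ones((4, 4))
+#   dice_loss(y, y)        # ~0.0
+#   dice_loss(y, 0.0 * y)  # ~1.0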
+
+
+
+
+""" Construct pipeline """
+
+# dir = '../path/Dataset'
+dir = './Dataset'
+
+target_size = 224
+batch_size = 24
+epochs = 124
+
+# read the csv file into a dataframe. os.path.join makes the code portable across operating systems
+df = pd.read_csv(os.path.join(dir, 'train.csv'))
+df['segmentation'] = df['segmentation'].fillna('0')
+
+# split into training, testing and validation sets
+train_ids, temp_ids = train_test_split(df.id.unique(), test_size=0.25, random_state=42)
+val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)
+
+# convert dfs into groupby objects to make sure rows are grouped by id
+train_grouped_df = df[df.id.isin(train_ids)].groupby('id')
+val_grouped_df = df[df.id.isin(val_ids)].groupby('id')
+test_grouped_df = df[df.id.isin(test_ids)].groupby('id')
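+# in this dataset each id has one row per class (large_bowel, small_bowel, stomach),
+# so the generator can stack the three class RLEs into a (H, W, 3) mask per image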
+
+
+# steps per epoch is typically train length / batch size to use all training examples
+train_steps_per_epoch = math.ceil(len(train_ids) / batch_size)
+val_steps_per_epoch = math.ceil(len(val_ids) / batch_size)
+test_steps_per_epoch = math.ceil(len(test_ids) / batch_size)
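+# For example, if there were 1,000 unique training ids and batch_size = 24,
+# math.ceil(1000 / 24) = 42 steps per epoch (the id count here is illustrative).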
+
+# create the training and validation datagens
+train_generator = custom_generator(train_grouped_df, dir, batch_size, (target_size, target_size))
+val_generator = custom_generator(val_grouped_df, dir, batch_size, (target_size, target_size))
+test_generator = custom_generator(test_grouped_df, dir, batch_size, (target_size, target_size), test_mode=True)
+
+
+
+
+
+""" Build the model or load the trained model """
+
+loading = True
+
+if loading:
+    weights_path = './impmodels/model_weights.h5'
+    model = TransUNet(image_size=224, pretrain=False)
+    model.load_weights(weights_path)
+    model.compile(optimizer='adam', loss=dice_loss, metrics=['accuracy'])
+else:
+    # create the optimizer and learning rate scheduler
+    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
+        initial_learning_rate = 1e-3,
+        # decay_steps=train_steps_per_epoch * epochs,
+        decay_steps=epochs+2,
+        alpha=1e-2 
+    )
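+    # Note: Keras learning-rate schedules are evaluated per optimizer step (batch),
+    # not per epoch, so decay_steps is the number of batches over which the rate
+    # falls from initial_learning_rate to alpha * initial_learning_rate
+    # (1e-3 down to 1e-5) along a cosine curve.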
+
+    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
+
+    # create the U-net neural network
+    model = TransUNet(image_size=target_size, freeze_enc_cnn=False, pretrain=True)
+    model.compile(optimizer=optimizer, loss=dice_loss, metrics=['accuracy'])
+
+    # set up model checkpoints and early stopping
+    checkpoints_path = os.path.join('Checkpoints', 'model_weights.h5')
+    model_checkpoint = ModelCheckpoint(filepath=checkpoints_path, save_best_only=True, monitor='val_loss')
+    early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=8)
+
+    # log the training to a .csv for reference
+    csv_logger = CSVLogger('training_log.csv', append=True)
+
+    history = model.fit(train_generator, validation_data=val_generator, steps_per_epoch=train_steps_per_epoch, validation_steps=val_steps_per_epoch, epochs=epochs, callbacks=[model_checkpoint, early_stopping, csv_logger])
+
+
+
+
+
+""" Display some predictions """
+
+preds = []
+ground_truths = []
+num_samples = 50
+
+# Generate predictions and ground truths
+for i in range(num_samples):
+    # Fetch a single (image, mask, id) sample from the test generator
+    image, mask, _ = next(test_generator)
+    
+    preds.append(model.predict(image))  # Predict using the model
+    ground_truths.append(mask)
+
+best_threshold = 0.99
+
+# Apply the best threshold to all predictions
+final_preds = [(pred >= best_threshold).astype(int) for pred in preds]
+
+# Compute Dice loss for each prediction
+for i in range(len(final_preds)):
+    loss = dice_loss(ground_truths[i], final_preds[i])
+    print(f"Image {i + 1}: Dice Loss = {float(loss):.4f}")
+
+
+
+def visualize_predictions(generator, model, num_samples=8, target_size=(224, 224)):
+    """
+    Visualize predictions vs. ground truths overlaid on original images.
+
+    Parameters:
+        generator (generator): Data generator
+        model (Model): Trained segmentation model
+        num_samples (int): Number of samples to visualize
+        target_size (tuple): Target size for resizing (default=(224, 224)).
+    """
+    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))
+
+    for i in range(num_samples):
+        # Fetch one image, mask, and id from the generator
+        image_batch, mask_batch, _ = next(generator)
+        image = image_batch[0]  # Single image
+        ground_truth = mask_batch[0]  # Corresponding ground truth mask
+
+        # Ensure image is RGB
+        if len(image.shape) == 2:
+            image = np.stack([image] * 3, axis=-1)  # Convert grayscale to RGB
+
+        # The mask has one channel per class; visualize only the first class channel
+        if ground_truth.ndim == 3 and ground_truth.shape[-1] == 3:
+            ground_truth = ground_truth[:, :, 0]
+
+        # Generate prediction
+        raw_prediction = model.predict(image[np.newaxis])[0]  # Add batch dimension for prediction
+
+        # The prediction also has one channel per class; keep the first class channel
+        if raw_prediction.ndim == 3 and raw_prediction.shape[-1] == 3:
+            prediction = raw_prediction[:, :, 0]
+        else:
+            prediction = raw_prediction
+        prediction = (prediction >= 0.99).astype(np.uint8)  # Threshold prediction
+
+        # Create overlays
+        gt_overlay = image.copy()
+        pred_overlay = image.copy()
+
+        # Overlay ground truth in red
+        gt_overlay[ground_truth == 1] = [255, 0, 0]
+
+        # Overlay prediction in green
+        pred_overlay[prediction == 1] = [0, 255, 0]
+
+        # Plot original image, ground truth overlay, and prediction overlay
+        axes[i, 0].imshow(image)
+        axes[i, 0].set_title(f"Image {i + 1}")
+        axes[i, 0].axis('off')
+
+        axes[i, 1].imshow(gt_overlay)
+        axes[i, 1].set_title(f"Ground Truth Overlay {i + 1}")
+        axes[i, 1].axis('off')
+
+        axes[i, 2].imshow(pred_overlay)
+        axes[i, 2].set_title(f"Prediction Overlay {i + 1}")
+        axes[i, 2].axis('off')
+
+    plt.tight_layout()
+    plt.show()
+
+
+# Call the function with your test generator and trained model
+visualize_predictions(test_generator, model, num_samples=24)
+
+
+
+
+
+def binary_to_rle(binary_mask):
+    """
+    Converts a binary mask to RLE (Run-Length Encoding).
+    """
+    # Flatten in row-major order so the encoding matches rle_to_binary's decoding
+    flat_mask = binary_mask.flatten(order='C')
+
+    rle = []
+    start = -1
+    for i, val in enumerate(flat_mask):
+        if val == 1 and start == -1:
+            start = i
+        elif val == 0 and start != -1:
+            rle.extend([start + 1, i - start])
+            start = -1
+    if start != -1:
+        rle.extend([start + 1, len(flat_mask) - start])
+    
+    return ' '.join(map(str, rle))
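+# Example (illustrative): binary_to_rle(rle_to_binary("1 3 10 2", (4, 4)))
+# returns "1 3 10 2" again, since encoding and decoding both use row-major order.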
+
+
+
+def save_predictions_to_csv(test_generator, model, output_csv_path, steps):
+    """
+    Generates predictions using the trained model and writes them to a CSV file in RLE format.
+    
+    Parameters:
+        test_generator: The data generator for the test set (test_mode=True).
+        model: The trained segmentation model.
+        output_csv_path: Path to save the CSV file.
+        steps: Number of samples to draw from the (infinite) test generator.
+    """
+    
+    with open(output_csv_path, mode='w', newline='') as csvfile:
+        csv_writer = csv.writer(csvfile)
+        csv_writer.writerow(['id', 'segmentation'])  # Header row
+
+        # The generator loops forever, so draw a fixed number of samples
+        for step in range(steps):
+            image, masks, ids = next(test_generator)
+            predictions = model.predict(image)
+            predictions = (predictions > 0.99).astype(int)
+
+            for pred_mask, mask_id in zip(predictions, ids):
+                pred_mask = pred_mask.squeeze()
+                if pred_mask.ndim == 3:
+                    pred_mask = pred_mask[:, :, 0]  # encode the first class channel only
+                rle = binary_to_rle(pred_mask)
+                csv_writer.writerow([mask_id, rle])
+
+            if (step + 1) % 100 == 0:
+                print(f"Processed {step + 1} predictions...")
+
+
+save_predictions_to_csv(test_generator, model, 'model_output.csv', steps=len(test_ids))
\ No newline at end of file