In [1]:
#imports
import os
import re
import glob
import math
import cv2
import gc
import csv 


import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from sklearn.model_selection import train_test_split
from transunet import TransUNet

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

from tensorflow import keras
from tensorflow.keras import layers

# Report the GPUs TensorFlow can see; an empty list means all ops run on the CPU.
gpus = tf.config.list_physical_devices('GPU')
print("GPUs: ", gpus)
gpu_message = "TensorFlow is using the GPU." if gpus else "TensorFlow is not using the GPU."
print(gpu_message)
2024-12-04 19:46:47.325076: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
GPUs:  []
TensorFlow is not using the GPU.
/home/des/anaconda3/lib/python3.10/site-packages/tensorflow_addons/utils/tfa_eol_msg.py:23: UserWarning: 

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 

  warnings.warn(
In [2]:
def rle_to_binary(rle, shape):
    """
    Converts a RLE (run length encoding) to a binary mask.
    """

    # Initialize a flat mask with zeros
    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    
    if rle == '' or rle == '0':  # Handle empty RLE
        return mask.reshape(shape, order='C')

    # Decode RLE into mask
    rle_numbers = list(map(int, rle.split()))
    for i in range(0, len(rle_numbers), 2):
        start = rle_numbers[i] - 1  # Convert to zero-indexed
        length = rle_numbers[i + 1]
        mask[start:start + length] = 1

    # Reshape flat mask into 2D
    return mask.reshape(shape, order='C')



def custom_generator(gdf, dir, batch_size, target_size=(224, 224), test_mode=False):
    """
    Custom data generator that dynamically aligns images and masks using RLE decoding.
    
    Parameters:
        gdf (GroupBy): Grouped dataframe containing image IDs and RLEs.
        dir (str): Root directory of the dataset.
        batch_size (int): Number of samples per batch.
        target_size (tuple): Target size for resizing (default=(224, 224)).
        test_mode (bool): If True, yields one image and mask at a time.
    """
    
    ids = list(gdf.groups.keys())
    dir2 = 'train'

    while True:
        sample_ids = np.random.choice(ids, size=batch_size, replace=False)
        images, masks = [], []

        for id_num in sample_ids:
            # Get the dataframe rows for the current image
            img_rows = gdf.get_group(id_num)
            rle_list = img_rows['segmentation'].tolist()
            
            # Construct the file path for the image
            sections = id_num.split('_')
            case = sections[0]
            day = sections[0] + '_' + sections[1]
            slice_id = sections[2] + '_' + sections[3]
            
            pattern = os.path.join(dir, dir2, case, day, "scans", f"{slice_id}*.png")
            filelist = glob.glob(pattern)
            
            if filelist:
                file = filelist[0]
                image = cv2.imread(file, cv2.IMREAD_COLOR)
                if image is None:
                    print(f"Image not found: {file}")
                    continue  # Skip if the image is missing
                
                # Original shape of the image
                original_shape = image.shape[:2]

                # Resize the image
                resized_image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)

                # Decode and resize the masks
                mask = np.zeros((target_size[0], target_size[1], len(rle_list)), dtype=np.uint8)
                for i, rle in enumerate(rle_list):
                    if rle != '0':  # Check if the RLE is valid
                        decoded_mask = rle_to_binary(rle, original_shape)
                        resized_mask = cv2.resize(decoded_mask, target_size, interpolation=cv2.INTER_NEAREST)
                        mask[:, :, i] = resized_mask

                if test_mode:
                    # Return individual samples in test mode
                    yield resized_image[np.newaxis], mask[np.newaxis], pattern
                else:
                    images.append(resized_image)
                    masks.append(mask)

        if not test_mode:
            x = np.array(images)
            y = np.array(masks)
            yield x, y

            
            
            
def dice_coef(y_true, y_pred, smooth=1e-6):
    # Ensure consistent data types
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    return 1 - dice_coef(y_true, y_pred)
In [3]:
# Dataset root: train.csv and the train/ image tree live underneath it.
# NOTE: this name shadows the builtin `dir` for the rest of the notebook.
dir = "./Dataset"
In [4]:
df = pd.read_csv(os.path.join('.', dir, 'train.csv'))
df['segmentation'] = df['segmentation'].fillna('0')

train_ids, temp_ids = train_test_split(df.id.unique(), test_size=0.1, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)

train_grouped_df = df[df.id.isin(train_ids)].groupby('id')
val_grouped_df = df[df.id.isin(val_ids)].groupby('id')
test_grouped_df = df[df.id.isin(test_ids)].groupby('id')

batch_size = 24
target_size = 224
epochs = 8

# steps per epoch is typically train length / batch size to use all training examples
train_steps_per_epoch = math.ceil(len(train_ids) / batch_size)
val_steps_per_epoch = math.ceil(len(val_ids) / batch_size)
test_steps_per_epoch = math.ceil(len(test_ids) / batch_size)

# create the training and validation datagens
train_generator = custom_generator(train_grouped_df, dir, batch_size, (target_size, target_size))
val_generator = custom_generator(val_grouped_df, dir, batch_size, (target_size, target_size))
test_generator = custom_generator(test_grouped_df, dir, batch_size, (target_size, target_size), test_mode=True)
In [5]:
loading = True

if loading:
    model = TransUNet(image_size=224, pretrain=False)
    model.load_weights('./impmodels/model_weights.h5')
    model.compile(optimizer='adam', loss=dice_loss, metrics=['accuracy'])
else:
    # create the optimizer and learning rate scheduler
    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate = 1e-3,
        # decay_steps=train_steps_per_epoch * epochs,
        decay_steps=epochs+2,
        alpha=1e-2 
    )

    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # create the U-net neural network
    model = TransUNet(image_size=target_size, freeze_enc_cnn=False, pretrain=True)
    model.compile(optimizer=optimizer, loss=dice_loss, metrics=['accuracy'])

    # set up model checkpoints and early stopping
    checkpoints_path = os.path.join('Checkpoints', 'model_weights.h5')
    model_checkpoint = ModelCheckpoint(filepath=checkpoints_path, save_best_only=True, monitor='val_loss')
    early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=8)

    # log the training to a .csv for reference
    csv_logger = CSVLogger('training_log.csv', append=True)

    history = model.fit(train_generator, validation_data=val_generator, steps_per_epoch=train_steps_per_epoch, validation_steps=val_steps_per_epoch, epochs=epochs, callbacks=[model_checkpoint, early_stopping, csv_logger])
2024-12-04 19:46:49.954136: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-04 19:46:49.956644: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
In [6]:
preds = []
ground_truths = []
num_samples = 24

# Generate predictions and ground truths
for i in range(num_samples):
    # Fetch a batch from the test generator
    batch = next(test_generator)
    image, mask, path = batch  # Assuming generator returns (images, masks)
    
    preds.append(model.predict(image))  # Predict using the model
    ground_truths.append(mask)

best_threshold = 0.99

# Apply the best threshold to all predictions
final_preds = [(pred >= best_threshold).astype(int) for pred in preds]

# Compute Dice loss for each prediction
for i in range(len(final_preds)):
    loss = dice_loss(ground_truths[i], final_preds[i])  # Assuming `dice_loss` is defined
    print(f"Image {i + 1}: Dice Loss = {loss:.4f}")
1/1 [==============================] - 3s 3s/step
1/1 [==============================] - 0s 407ms/step
1/1 [==============================] - 0s 454ms/step
1/1 [==============================] - 0s 431ms/step
1/1 [==============================] - 0s 387ms/step
1/1 [==============================] - 0s 393ms/step
1/1 [==============================] - 0s 337ms/step
1/1 [==============================] - 0s 376ms/step
1/1 [==============================] - 0s 462ms/step
1/1 [==============================] - 0s 383ms/step
1/1 [==============================] - 0s 413ms/step
1/1 [==============================] - 0s 377ms/step
1/1 [==============================] - 0s 373ms/step
1/1 [==============================] - 1s 737ms/step
1/1 [==============================] - 0s 316ms/step
1/1 [==============================] - 0s 439ms/step
1/1 [==============================] - 0s 408ms/step
1/1 [==============================] - 0s 436ms/step
1/1 [==============================] - 1s 512ms/step
1/1 [==============================] - 0s 365ms/step
1/1 [==============================] - 0s 345ms/step
1/1 [==============================] - 0s 373ms/step
1/1 [==============================] - 0s 480ms/step
1/1 [==============================] - 0s 366ms/step
Image 1: Dice Loss = 0.0626
Image 2: Dice Loss = 0.0000
Image 3: Dice Loss = 1.0000
Image 4: Dice Loss = 0.1234
Image 5: Dice Loss = 0.0000
Image 6: Dice Loss = 0.0000
Image 7: Dice Loss = 0.0000
Image 8: Dice Loss = 0.1158
Image 9: Dice Loss = 0.0000
Image 10: Dice Loss = 0.1349
Image 11: Dice Loss = 0.0000
Image 12: Dice Loss = 0.0708
Image 13: Dice Loss = 0.0000
Image 14: Dice Loss = 0.1931
Image 15: Dice Loss = 0.0000
Image 16: Dice Loss = 0.0601
Image 17: Dice Loss = 0.0468
Image 18: Dice Loss = 0.0000
Image 19: Dice Loss = 0.0000
Image 20: Dice Loss = 0.0000
Image 21: Dice Loss = 0.0000
Image 22: Dice Loss = 0.0000
Image 23: Dice Loss = 0.0000
Image 24: Dice Loss = 0.0000
In [7]:
def visualize_predictions(generator, model, num_samples=8, target_size=(224, 224)):
    """
    Visualize predictions vs. ground truths overlaid on original images.

    Parameters:
        generator (generator): Data generator for test data.
        model (Model): Trained segmentation model.
        num_samples (int): Number of samples to visualize.
        target_size (tuple): Target size for resizing (default=(224, 224)).
    """
    
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))

    for i in range(num_samples):
        # Fetch one image and mask from the generator
        image_batch, mask_batch, path = next(generator)
        image = image_batch[0]  # Single image
        ground_truth = mask_batch[0]  # Corresponding ground truth mask

        # Ensure image is RGB
        if len(image.shape) == 2:
            image = np.stack([image] * 3, axis=-1)  # Convert grayscale to RGB

        # Ensure ground truth is a single-channel binary mask
        if ground_truth.ndim == 3 and ground_truth.shape[-1] == 3:
            ground_truth = ground_truth[:, :, 0]  # Extract the first channel

        # Generate prediction
        raw_prediction = model.predict(image[np.newaxis])[0]  # Add batch dimension for prediction

        # Ensure prediction is single-channel
        if raw_prediction.ndim == 3 and raw_prediction.shape[-1] == 3:
            prediction = raw_prediction[:, :, 0]  # Extract the first channel
        else:
            prediction = raw_prediction
        prediction = (prediction >= 0.99).astype(np.uint8)  # Threshold prediction

        # Create overlays
        gt_overlay = image.copy()
        pred_overlay = image.copy()

        # Overlay ground truth in red
        gt_overlay[ground_truth == 1] = [255, 0, 0]

        # Overlay prediction in green
        pred_overlay[prediction == 1] = [0, 255, 0]

        # Plot original image, ground truth overlay, and prediction overlay
        axes[i, 0].imshow(image)
        axes[i, 0].set_title(f"Image {i + 1}")
        axes[i, 0].axis('off')

        axes[i, 1].imshow(gt_overlay)
        axes[i, 1].set_title(f"Ground Truth Overlay {i + 1}")
        axes[i, 1].axis('off')

        axes[i, 2].imshow(pred_overlay)
        axes[i, 2].set_title(f"Prediction Overlay {i + 1}")
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()


# Call the function with test generator and trained model
visualize_predictions(test_generator, model, num_samples=24)
1/1 [==============================] - 0s 466ms/step
1/1 [==============================] - 0s 401ms/step
1/1 [==============================] - 0s 391ms/step
1/1 [==============================] - 0s 380ms/step
1/1 [==============================] - 0s 389ms/step
1/1 [==============================] - 0s 407ms/step
1/1 [==============================] - 0s 401ms/step
1/1 [==============================] - 0s 334ms/step
1/1 [==============================] - 0s 434ms/step
1/1 [==============================] - 0s 427ms/step
1/1 [==============================] - 0s 368ms/step
1/1 [==============================] - 0s 379ms/step
1/1 [==============================] - 1s 550ms/step
1/1 [==============================] - 0s 342ms/step
1/1 [==============================] - 0s 402ms/step
1/1 [==============================] - 0s 454ms/step
1/1 [==============================] - 0s 396ms/step
1/1 [==============================] - 0s 433ms/step
1/1 [==============================] - 0s 421ms/step
1/1 [==============================] - 0s 333ms/step
1/1 [==============================] - 0s 417ms/step
1/1 [==============================] - 0s 398ms/step
1/1 [==============================] - 0s 344ms/step
1/1 [==============================] - 0s 390ms/step
No description has been provided for this image
In [8]:
def binary_to_rle(binary_mask):
    """
    Converts a binary mask to RLE (Run-Length Encoding).
    """
    
    # Flatten mask in column-major order
    flat_mask = binary_mask.T.flatten()

    rle = []
    start = -1
    for i, val in enumerate(flat_mask):
        if val == 1 and start == -1:
            start = i
        elif val == 0 and start != -1:
            rle.extend([start + 1, i - start])
            start = -1
    if start != -1:
        rle.extend([start + 1, len(flat_mask) - start])
    
    return ' '.join(map(str, rle))
In [9]:
def save_predictions_to_csv(test_generator, model, output_csv_path):
    """
    Generates predictions using the trained model and writes them to a CSV file in RLE format.
    
    Parameters:
        test_generator: The data generator for the test set.
        model: The trained segmentation model.
        output_csv_path: Path to save the CSV file.
    """
    
    with open(output_csv_path, mode='w', newline='') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(['id', 'segmentation'])  # Header row

        i = 0
        for image, masks, ids in test_generator:
            i += 1
            if i > 81:
                break;
                
            predictions = model.predict(image)
            predictions = (predictions > 0.99).astype(int)  

            for pred_mask, mask_id in zip(predictions, ids):
                rle = binary_to_rle(pred_mask.squeeze())
                csv_writer.writerow([mask_id, rle])

            # Clear memory after each batch just in case
            del predictions, image, masks, ids
            gc.collect()
In [10]:
predictions_output_path = 'model_output.csv'
save_predictions_to_csv(test_generator, model, predictions_output_path)
1/1 [==============================] - 0s 489ms/step
1/1 [==============================] - 0s 448ms/step
1/1 [==============================] - 0s 394ms/step
1/1 [==============================] - 0s 447ms/step
1/1 [==============================] - 0s 379ms/step
1/1 [==============================] - 0s 465ms/step
1/1 [==============================] - 0s 384ms/step
1/1 [==============================] - 0s 342ms/step
1/1 [==============================] - 0s 457ms/step
1/1 [==============================] - 0s 360ms/step
1/1 [==============================] - 0s 389ms/step
1/1 [==============================] - 0s 457ms/step
1/1 [==============================] - 0s 479ms/step
1/1 [==============================] - 0s 495ms/step
1/1 [==============================] - 1s 519ms/step
1/1 [==============================] - 0s 430ms/step
1/1 [==============================] - 0s 459ms/step
1/1 [==============================] - 0s 409ms/step
1/1 [==============================] - 0s 499ms/step
1/1 [==============================] - 0s 390ms/step
1/1 [==============================] - 0s 380ms/step
1/1 [==============================] - 0s 468ms/step
1/1 [==============================] - 0s 465ms/step
1/1 [==============================] - 0s 400ms/step
1/1 [==============================] - 0s 372ms/step
1/1 [==============================] - 0s 486ms/step
1/1 [==============================] - 0s 342ms/step
1/1 [==============================] - 0s 446ms/step
1/1 [==============================] - 0s 461ms/step
1/1 [==============================] - 0s 394ms/step
1/1 [==============================] - 1s 517ms/step
1/1 [==============================] - 1s 574ms/step
1/1 [==============================] - 0s 417ms/step
1/1 [==============================] - 0s 413ms/step
1/1 [==============================] - 0s 422ms/step
1/1 [==============================] - 0s 452ms/step
1/1 [==============================] - 0s 469ms/step
1/1 [==============================] - 0s 485ms/step
1/1 [==============================] - 0s 476ms/step
1/1 [==============================] - 0s 429ms/step
1/1 [==============================] - 0s 394ms/step
1/1 [==============================] - 0s 395ms/step
1/1 [==============================] - 0s 429ms/step
1/1 [==============================] - 0s 360ms/step
1/1 [==============================] - 0s 388ms/step
1/1 [==============================] - 0s 479ms/step
1/1 [==============================] - 0s 441ms/step
1/1 [==============================] - 1s 502ms/step
1/1 [==============================] - 0s 425ms/step
1/1 [==============================] - 1s 711ms/step
1/1 [==============================] - 1s 530ms/step
1/1 [==============================] - 0s 439ms/step
1/1 [==============================] - 0s 451ms/step
1/1 [==============================] - 0s 436ms/step
1/1 [==============================] - 1s 551ms/step
1/1 [==============================] - 0s 394ms/step
1/1 [==============================] - 1s 528ms/step
1/1 [==============================] - 0s 386ms/step
1/1 [==============================] - 1s 598ms/step
1/1 [==============================] - 1s 588ms/step
1/1 [==============================] - 1s 554ms/step
1/1 [==============================] - 0s 481ms/step
1/1 [==============================] - 0s 354ms/step
1/1 [==============================] - 0s 497ms/step
1/1 [==============================] - 1s 664ms/step
1/1 [==============================] - 1s 787ms/step
1/1 [==============================] - 1s 884ms/step
1/1 [==============================] - 1s 1s/step
1/1 [==============================] - 1s 828ms/step
1/1 [==============================] - 1s 663ms/step
1/1 [==============================] - 1s 895ms/step
1/1 [==============================] - 1s 882ms/step
1/1 [==============================] - 1s 514ms/step
1/1 [==============================] - 0s 499ms/step
1/1 [==============================] - 0s 362ms/step
1/1 [==============================] - 0s 467ms/step
1/1 [==============================] - 0s 477ms/step
1/1 [==============================] - 0s 397ms/step
1/1 [==============================] - 0s 415ms/step
1/1 [==============================] - 0s 484ms/step
1/1 [==============================] - 0s 474ms/step