# GI-Tract-Image-Segmentation.py
""" Import statements and check for GPU """

import os
import glob
import math
import csv

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from transunet import TransUNet

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

# List available GPUs
gpus = tf.config.list_physical_devices('GPU')
print("GPUs: ", gpus)

if gpus:
    print("TensorFlow is using the GPU.")
else:
    print("TensorFlow is not using the GPU.")

""" Function Definitions """
40
41
def rle_to_binary(rle, shape):
42
    """
43
    Decodes run length encoded masks into a binary image
44
45
    Parameters:
46
        rle (list): list containing the starts and lengths that make up each RLE mask
47
        shape (tuple): the original shape of the associated image
48
    """
49
50
    # Initialize a flat mask with zeros
51
    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
52
    
53
    if rle == '' or rle == '0':  # Handle empty RLE
54
        return mask.reshape(shape, order='C')
55
56
    # Decode RLE into mask
57
    rle_numbers = list(map(int, rle.split()))
58
    for i in range(0, len(rle_numbers), 2):
59
        start = rle_numbers[i] - 1  # Convert to zero-indexed
60
        length = rle_numbers[i + 1]
61
        mask[start:start + length] = 1
62
63
    # Reshape flat mask into 2D
64
    return mask.reshape(shape, order='C')
65
66
67
68
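
# Quick sanity check of the decoder (illustrative, assumes the row-major
# convention implemented above): RLE "2 3" marks flat pixels 1..3
# (zero-indexed) of a 3x3 image as foreground.
_demo_mask = rle_to_binary('2 3', (3, 3))
assert _demo_mask.tolist() == [[0, 1, 1], [1, 0, 0], [0, 0, 0]]
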
def custom_generator(gdf, data_dir, batch_size, target_size=(224, 224), test_mode=False):
    """
    Custom data generator that dynamically aligns images and masks using RLE decoding.

    Parameters:
        gdf (GroupBy): Grouped dataframe containing image IDs and RLEs.
        data_dir (str): Root directory of the dataset.
        batch_size (int): Number of samples per batch.
        target_size (tuple): Target size for resizing (default=(224, 224)).
        test_mode (bool): If True, yields one image, mask, and image ID at a time.
    """

    ids = list(gdf.groups.keys())
    subdir = 'train'

    while True:
        sample_ids = np.random.choice(ids, size=batch_size, replace=False)
        images, masks = [], []

        for id_num in sample_ids:
            # Get the dataframe rows for the current image
            img_rows = gdf.get_group(id_num)
            rle_list = img_rows['segmentation'].tolist()

            # Construct the file path for the image
            sections = id_num.split('_')
            case = sections[0]
            day = sections[0] + '_' + sections[1]
            slice_id = sections[2] + '_' + sections[3]

            pattern = os.path.join(data_dir, subdir, case, day, "scans", f"{slice_id}*.png")
            filelist = glob.glob(pattern)

            if filelist:
                file = filelist[0]
                image = cv2.imread(file, cv2.IMREAD_COLOR)
                if image is None:
                    print(f"Image not found: {file}")
                    continue  # Skip if the image is missing

                # Original shape of the image
                original_shape = image.shape[:2]

                # Resize the image
                resized_image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)

                # Decode and resize the masks (one channel per class)
                mask = np.zeros((target_size[0], target_size[1], len(rle_list)), dtype=np.uint8)
                for i, rle in enumerate(rle_list):
                    if rle != '0':  # Check if the RLE is valid
                        decoded_mask = rle_to_binary(rle, original_shape)
                        resized_mask = cv2.resize(decoded_mask, target_size, interpolation=cv2.INTER_NEAREST)
                        mask[:, :, i] = resized_mask

                if test_mode:
                    # Yield individual samples, with their IDs, in test mode
                    yield resized_image[np.newaxis], mask[np.newaxis], [id_num]
                else:
                    images.append(resized_image)
                    masks.append(mask)

        if not test_mode:
            # Keras expects (inputs, targets) pairs from a training generator
            yield np.array(images), np.array(masks)
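
# Illustrative ID parsing (hypothetical ID): the generator splits an ID such as
# 'case123_day20_slice_0001' into case='case123', day='case123_day20', and
# slice_id='slice_0001' to build the scan path.
_parts = 'case123_day20_slice_0001'.split('_')
assert (_parts[0], _parts[0] + '_' + _parts[1], _parts[2] + '_' + _parts[3]) == \
    ('case123', 'case123_day20', 'slice_0001')
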
""" Loss function: dice loss ignores negative class thus negating class imbalance issues """     
139
140
def dice_coef(y_true, y_pred, smooth=1e-6):
141
    # Ensure consistent data types
142
    y_true = tf.cast(y_true, tf.float32)
143
    y_pred = tf.cast(y_pred, tf.float32)
144
145
    y_true_f = tf.keras.backend.flatten(y_true)
146
    y_pred_f = tf.keras.backend.flatten(y_pred)
147
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
148
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
149
150
def dice_loss(y_true, y_pred):
151
    y_true = tf.cast(y_true, tf.float32)
152
    y_pred = tf.cast(y_pred, tf.float32)
153
    return 1 - dice_coef(y_true, y_pred)
154
155
156
157
158
159
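
# Illustrative check (not part of training): identical masks should give a
# Dice coefficient near 1 and therefore a loss near 0.
_demo_ones = tf.ones((1, 4, 4, 1))
assert float(dice_loss(_demo_ones, _demo_ones)) < 1e-5
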
""" Construct pipeline """
160
161
# dir = '../path/Dataset'
162
dir = './Dataset'
163
164
target_size = 224
165
batch_size = 24
166
epochs = 124
167
168
# read the csv file into a dataframe. os.path.join makes code executable across operating systes
169
df = pd.read_csv(os.path.join('.', dir, 'train.csv'))
170
df['segmentation'] = df['segmentation'].fillna('0')
171
172
# split into training, testing and validation sets
173
train_ids, temp_ids = train_test_split(df.id.unique(), test_size=0.25, random_state=42)
174
val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)
175
176
# convert dfs into groupby objects to make sure rows are grouped by id
177
train_grouped_df = df[df.id.isin(train_ids)].groupby('id')
178
val_grouped_df = df[df.id.isin(val_ids)].groupby('id')
179
test_grouped_df = df[df.id.isin(test_ids)].groupby('id')
180
181
182
# steps per epoch is typically train length / batch size to use all training examples
183
train_steps_per_epoch = math.ceil(len(train_ids) / batch_size)
184
val_steps_per_epoch = math.ceil(len(val_ids) / batch_size)
185
test_steps_per_epoch = math.ceil(len(test_ids) / batch_size)
186
187
# create the training and validation datagens
188
train_generator = custom_generator(train_grouped_df, dir, batch_size, (target_size, target_size))
189
val_generator = custom_generator(val_grouped_df, dir, batch_size, (target_size, target_size))
190
test_generator = custom_generator(test_grouped_df, dir, batch_size, (target_size, target_size), test_mode=True)
191
192
193
194
195
196
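
# Quick sanity check (illustrative; uncomment to run, since it loads real data):
# pull one training batch and confirm the shapes are (batch, height, width, channels).
# x, y = next(train_generator)
# print(x.shape, y.shape)  # expected: (24, 224, 224, 3) for both
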
""" Build the model or load the trained model """
197
198
loading = True
199
200
if loading:
201
    weights_path = './impmodels/model_weights.h5'
202
    model = TransUNet(image_size=224, pretrain=False)
203
    model.load_weights(weights_path)
204
    model.compile(optimizer='adam', loss=dice_loss, metrics=['accuracy'])
205
else:
206
    # create the optimizer and learning rate scheduler
207
    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
208
        initial_learning_rate = 1e-3,
209
        # decay_steps=train_steps_per_epoch * epochs,
210
        decay_steps=epochs+2,
211
        alpha=1e-2 
212
    )
213
214
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
215
216
    # create the U-net neural network
217
    model = TransUNet(image_size=target_size, freeze_enc_cnn=False, pretrain=True)
218
    model.compile(optimizer=optimizer, loss=dice_loss, metrics=['accuracy'])
219
220
    # set up model checkpoints and early stopping
221
    checkpoints_path = os.path.join('Checkpoints', 'model_weights.h5')
222
    model_checkpoint = ModelCheckpoint(filepath=checkpoints_path, save_best_only=True, monitor='val_loss')
223
    early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=8)
224
225
    # log the training to a .csv for reference
226
    csv_logger = CSVLogger('training_log.csv', append=True)
227
228
    history = model.fit(train_generator, validation_data=val_generator, steps_per_epoch=train_steps_per_epoch, validation_steps=val_steps_per_epoch, epochs=epochs, callbacks=[model_checkpoint, early_stopping, csv_logger])
229
230
231
232
233
234
""" Display some predictions """
235
236
preds = []
237
ground_truths = []
238
num_samples = 50
239
240
# Generate predictions and ground truths
241
for i in range(num_samples):
242
    # Fetch a batch from the test generator
243
    batch = next(test_generator)
244
    image, mask = batch  
245
    
246
    preds.append(model.predict(image))  # Predict using the model
247
    ground_truths.append(mask)
248
249
best_threshold = 0.99
250
251
# Apply the best threshold to all predictions
252
final_preds = [(pred >= best_threshold).astype(int) for pred in preds]
253
254
# Compute Dice loss for each prediction
255
for i in range(len(final_preds)):
256
    loss = dice_loss(ground_truths[i], final_preds[i])  
257
    print(f"Image {i + 1}: Dice Loss = {loss:.4f}")
258
259
260
261
def visualize_predictions(generator, model, num_samples=8, target_size=(224, 224)):
    """
    Visualize predictions vs. ground truths overlaid on original images.

    Parameters:
        generator (generator): Data generator (test mode)
        model (Model): Trained segmentation model
        num_samples (int): Number of samples to visualize
        target_size (tuple): Target size for resizing (default=(224, 224)).
    """
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))

    for i in range(num_samples):
        # Fetch one image and mask from the generator (test mode also yields the ID)
        image_batch, mask_batch, _ = next(generator)
        image = image_batch[0]  # Single image
        ground_truth = mask_batch[0]  # Corresponding ground truth mask

        # Ensure image is RGB
        if len(image.shape) == 2:
            image = np.stack([image] * 3, axis=-1)  # Convert grayscale to RGB

        # Ensure ground truth is a single-channel binary mask
        if ground_truth.ndim == 3 and ground_truth.shape[-1] == 3:
            ground_truth = ground_truth[:, :, 0]  # Extract the first channel

        # Generate prediction (add batch dimension for the model)
        raw_prediction = model.predict(image[np.newaxis])[0]

        # Ensure prediction is single-channel
        if raw_prediction.ndim == 3 and raw_prediction.shape[-1] == 3:
            prediction = raw_prediction[:, :, 0]  # Extract the first channel
        else:
            prediction = raw_prediction
        prediction = (prediction >= 0.99).astype(np.uint8)  # Threshold prediction

        # Create overlays
        gt_overlay = image.copy()
        pred_overlay = image.copy()

        # Overlay ground truth in red
        gt_overlay[ground_truth == 1] = [255, 0, 0]

        # Overlay prediction in green
        pred_overlay[prediction == 1] = [0, 255, 0]

        # Plot original image, ground truth overlay, and prediction overlay
        axes[i, 0].imshow(image)
        axes[i, 0].set_title(f"Image {i + 1}")
        axes[i, 0].axis('off')

        axes[i, 1].imshow(gt_overlay)
        axes[i, 1].set_title(f"Ground Truth Overlay {i + 1}")
        axes[i, 1].axis('off')

        axes[i, 2].imshow(pred_overlay)
        axes[i, 2].set_title(f"Prediction Overlay {i + 1}")
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()


# Call the function with the test generator and trained model
visualize_predictions(test_generator, model, num_samples=24)

def binary_to_rle(binary_mask):
    """
    Converts a binary mask to RLE (Run-Length Encoding).
    """
    # Flatten in row-major order to match the convention used by rle_to_binary
    flat_mask = binary_mask.flatten()

    rle = []
    start = -1
    for i, val in enumerate(flat_mask):
        if val == 1 and start == -1:
            start = i  # A run begins
        elif val == 0 and start != -1:
            rle.extend([start + 1, i - start])  # Record one-indexed start and run length
            start = -1
    if start != -1:  # Close a run that reaches the end of the mask
        rle.extend([start + 1, len(flat_mask) - start])

    return ' '.join(map(str, rle))
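
# Illustrative round-trip check (assumes the shared row-major convention):
# encoding then decoding should reproduce the original mask.
_demo = np.array([[0, 1, 1], [1, 0, 0], [0, 0, 0]], dtype=np.uint8)
assert rle_to_binary(binary_to_rle(_demo), _demo.shape).tolist() == _demo.tolist()
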
def save_predictions_to_csv(test_generator, model, output_csv_path, num_samples):
    """
    Generates predictions using the trained model and writes them to a CSV file in RLE format.

    Parameters:
        test_generator: The data generator for the test set (test mode).
        model: The trained segmentation model.
        output_csv_path: Path to save the CSV file.
        num_samples: Number of samples to write (the generator loops forever).
    """
    with open(output_csv_path, mode='w', newline='') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(['id', 'segmentation'])  # Header row

        # The generator is infinite, so draw a fixed number of samples
        for _ in range(num_samples):
            image, masks, ids = next(test_generator)
            predictions = model.predict(image)
            predictions = (predictions > 0.99).astype(int)

            for pred_mask, mask_id in zip(predictions, ids):
                # Note: pred_mask may have one channel per class; they are encoded together here
                rle = binary_to_rle(pred_mask.squeeze())
                csv_writer.writerow([mask_id, rle])

        print(f"Processed {num_samples} predictions...")


save_predictions_to_csv(test_generator, model, 'model_output.csv', num_samples=len(test_ids))