|
a |
|
b/GI-Tract-Image-Segmentation.py |
|
|
1 |
""" Import statements and check for GPU """ |
|
|
2 |
|
|
|
3 |
import os |
|
|
4 |
import re |
|
|
5 |
import glob |
|
|
6 |
import math |
|
|
7 |
import cv2 |
|
|
8 |
import csv |
|
|
9 |
|
|
|
10 |
import pandas as pd |
|
|
11 |
import matplotlib.pyplot as plt |
|
|
12 |
import numpy as np |
|
|
13 |
|
|
|
14 |
from sklearn.model_selection import train_test_split |
|
|
15 |
from transunet import TransUNet |
|
|
16 |
|
|
|
17 |
import tensorflow as tf |
|
|
18 |
from tensorflow.keras.models import Model |
|
|
19 |
from tensorflow.keras.optimizers import Adam |
|
|
20 |
from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
|
21 |
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger |
|
|
22 |
|
|
|
23 |
from tensorflow import keras |
|
|
24 |
from tensorflow.keras import layers |
|
|
25 |
|
|
|
26 |
# Ask TensorFlow which physical GPU devices it can see and report the result
gpus = tf.config.list_physical_devices('GPU')
print("GPUs: ", gpus)

# An empty device list is falsy, which selects the second message
print("TensorFlow is using the GPU." if gpus else "TensorFlow is not using the GPU.")
|
|
34 |
|
|
|
35 |
|
|
|
36 |
|
|
|
37 |
|
|
|
38 |
|
|
|
39 |
""" Function Definitions """ |
|
|
40 |
|
|
|
41 |
def rle_to_binary(rle, shape):
    """
    Decodes a run-length encoded (RLE) mask into a binary image.

    The RLE text holds whitespace-separated ``start length`` pairs. Starts are
    1-indexed positions in the row-major (C order) flattened image.

    Parameters:
        rle (str): RLE text such as ``'3 2 10 4'``. ``''``, ``'0'``, ``None``
            and NaN all stand for "no mask".
        shape (tuple): (height, width) of the image the mask belongs to.

    Returns:
        np.ndarray: uint8 array of the requested shape, 1 inside the runs and
        0 everywhere else.

    Raises:
        ValueError: if the RLE text holds an odd number of values, i.e. a
            start without a matching length.
    """

    # Initialize a flat mask with zeros
    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)

    # Missing annotations arrive as None/NaN when the caller has not replaced
    # them yet; treat them exactly like the explicit empty markers.
    if rle is None or (isinstance(rle, float) and np.isnan(rle)):
        return mask.reshape(shape, order='C')

    rle_numbers = [int(token) for token in rle.split()]
    if not rle_numbers or rle_numbers == [0]:  # Handle empty RLE
        return mask.reshape(shape, order='C')

    if len(rle_numbers) % 2:
        raise ValueError(
            f"RLE must contain start/length pairs, got {len(rle_numbers)} values"
        )

    # Decode RLE into mask, one (start, length) pair at a time
    for start, length in zip(rle_numbers[0::2], rle_numbers[1::2]):
        begin = start - 1  # Convert to zero-indexed
        mask[begin:begin + length] = 1

    # Reshape flat mask into 2D
    return mask.reshape(shape, order='C')
|
|
65 |
|
|
|
66 |
|
|
|
67 |
|
|
|
68 |
def custom_generator(gdf, dir, batch_size, target_size=(224, 224), test_mode=False):
    """
    Custom data generator that dynamically aligns images and masks using RLE decoding.

    The generator never terminates: every pass draws a fresh random set of ids
    (without replacement inside one pass), loads the matching scans and decodes
    their masks.

    Parameters:
        gdf (GroupBy): Dataframe grouped by image id; each group provides a
            'segmentation' column with one RLE string per mask channel.
        dir (str): Root directory of the dataset.
        batch_size (int): Number of samples per batch. Capped at the number of
            available ids, because sampling without replacement cannot draw more.
        target_size (tuple): (height, width) the images and masks are resized
            to (default=(224, 224)).
        test_mode (bool): If True, yields one image and mask at a time.

    Yields:
        tuple: in test mode ``(image[np.newaxis], mask[np.newaxis], pattern)``
        where ``pattern`` is the glob pattern used to locate the scan;
        otherwise ``(images, masks, None)`` holding the samples of the batch
        that could be loaded.

    Raises:
        ValueError: if ``gdf`` contains no groups.
    """

    ids = list(gdf.groups.keys())
    if not ids:
        raise ValueError("custom_generator needs at least one image id")
    dir2 = 'train'

    # replace=False raises when asked for more samples than there are ids
    draw_size = min(batch_size, len(ids))

    # cv2.resize expects dsize as (width, height) while numpy arrays are
    # (rows, cols); keeping both spellings makes non-square targets line up.
    height, width = target_size
    cv_size = (width, height)

    while True:
        sample_ids = np.random.choice(ids, size=draw_size, replace=False)
        images, masks = [], []

        for id_num in sample_ids:
            # Get the dataframe rows for the current image
            img_rows = gdf.get_group(id_num)
            rle_list = img_rows['segmentation'].tolist()

            # Construct the file path for the image
            sections = id_num.split('_')
            case = sections[0]
            day = sections[0] + '_' + sections[1]
            slice_id = sections[2] + '_' + sections[3]

            pattern = os.path.join(dir, dir2, case, day, "scans", f"{slice_id}*.png")
            filelist = glob.glob(pattern)
            if not filelist:
                # Report the miss instead of dropping the sample silently
                print(f"No scan matched: {pattern}")
                continue

            file = filelist[0]
            image = cv2.imread(file, cv2.IMREAD_COLOR)
            if image is None:
                print(f"Image not found: {file}")
                continue  # Skip if the image is missing

            # Original shape of the image, needed to decode the RLE strings
            original_shape = image.shape[:2]

            # Resize the image
            resized_image = cv2.resize(image, cv_size, interpolation=cv2.INTER_LINEAR)

            # Decode and resize the masks, one channel per RLE entry
            mask = np.zeros((height, width, len(rle_list)), dtype=np.uint8)
            for i, rle in enumerate(rle_list):
                if rle != '0':  # '0' marks an empty mask; its channel stays zero
                    decoded_mask = rle_to_binary(rle, original_shape)
                    # Nearest-neighbour keeps the resized mask strictly binary
                    resized_mask = cv2.resize(decoded_mask, cv_size, interpolation=cv2.INTER_NEAREST)
                    mask[:, :, i] = resized_mask

            if test_mode:
                # Return individual samples in test mode
                yield resized_image[np.newaxis], mask[np.newaxis], pattern
            else:
                images.append(resized_image)
                masks.append(mask)

        if not test_mode:
            x = np.array(images)
            y = np.array(masks)
            yield x, y, None
|
|
133 |
|
|
|
134 |
|
|
|
135 |
|
|
|
136 |
|
|
|
137 |
|
|
|
138 |
""" Loss function: dice loss ignores negative class thus negating class imbalance issues """ |
|
|
139 |
|
|
|
140 |
def dice_coef(y_true, y_pred, smooth=1e-6):
    """
    Dice coefficient between two tensors, computed over all elements at once.

    Parameters:
        y_true: ground-truth values, cast to float32 before use
        y_pred: predicted values, cast to float32 before use
        smooth (float): added to numerator and denominator so the ratio stays
            defined when both inputs sum to zero
    """
    # Cast first so integer masks and float predictions can be multiplied
    truth_flat = tf.keras.backend.flatten(tf.cast(y_true, tf.float32))
    pred_flat = tf.keras.backend.flatten(tf.cast(y_pred, tf.float32))

    overlap = tf.reduce_sum(truth_flat * pred_flat)
    total = tf.reduce_sum(truth_flat) + tf.reduce_sum(pred_flat)
    return (2. * overlap + smooth) / (total + smooth)
|
|
149 |
|
|
|
150 |
def dice_loss(y_true, y_pred):
    """
    Dice loss: one minus the Dice coefficient of the float32-cast inputs.
    """
    truth = tf.cast(y_true, tf.float32)
    prediction = tf.cast(y_pred, tf.float32)
    coefficient = dice_coef(truth, prediction)
    return 1 - coefficient
|
|
154 |
|
|
|
155 |
|
|
|
156 |
|
|
|
157 |
|
|
|
158 |
|
|
|
159 |
""" Construct pipeline """ |
|
|
160 |
|
|
|
161 |
# Dataset root (alternative location: '../path/Dataset')
dir = './Dataset'

target_size, batch_size, epochs = 224, 24, 124

# Read the csv file into a dataframe. os.path.join makes the code executable across operating systems
df = pd.read_csv(os.path.join('.', dir, 'train.csv'))
# Rows without an annotation get the '0' marker that the generator treats as "no mask"
df['segmentation'] = df['segmentation'].fillna('0')

# Split the unique ids: 75% training, then the remaining 25% is halved into validation and test
train_ids, temp_ids = train_test_split(df['id'].unique(), test_size=0.25, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)

# Convert each split into a groupby object so that all rows of one id stay together
train_grouped_df, val_grouped_df, test_grouped_df = (
    df[df['id'].isin(split_ids)].groupby('id')
    for split_ids in (train_ids, val_ids, test_ids)
)

# Steps per epoch is typically split length / batch size so that one epoch covers the whole split
train_steps_per_epoch, val_steps_per_epoch, test_steps_per_epoch = (
    math.ceil(len(split_ids) / batch_size)
    for split_ids in (train_ids, val_ids, test_ids)
)

# Create the training, validation and test generators (the test one yields single samples)
train_generator = custom_generator(train_grouped_df, dir, batch_size, (target_size, target_size))
val_generator = custom_generator(val_grouped_df, dir, batch_size, (target_size, target_size))
test_generator = custom_generator(
    test_grouped_df, dir, batch_size, (target_size, target_size), test_mode=True
)
|
|
191 |
|
|
|
192 |
|
|
|
193 |
|
|
|
194 |
|
|
|
195 |
|
|
|
196 |
""" Build the model or load the trained model """ |
|
|
197 |
|
|
|
198 |
# True restores previously saved weights; False trains a new model
loading = True

if not loading:
    # Create the learning rate scheduler and the optimizer
    # NOTE(review): decay_steps is epochs + 2 here; a per-step variant
    # (train_steps_per_epoch * epochs) was considered previously - confirm
    # which unit the schedule is meant to count.
    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-3,
        decay_steps=epochs + 2,
        alpha=1e-2,
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # Create the TransUNet network and attach the Dice loss
    model = TransUNet(image_size=target_size, freeze_enc_cnn=False, pretrain=True)
    model.compile(optimizer=optimizer, loss=dice_loss, metrics=['accuracy'])

    # Keep only the best weights (by validation loss) and stop once it stalls for 8 epochs
    checkpoints_path = os.path.join('Checkpoints', 'model_weights.h5')
    model_checkpoint = ModelCheckpoint(
        filepath=checkpoints_path,
        save_best_only=True,
        monitor='val_loss',
    )
    early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=8)

    # Log the training to a .csv for reference
    csv_logger = CSVLogger('training_log.csv', append=True)

    history = model.fit(
        train_generator,
        validation_data=val_generator,
        steps_per_epoch=train_steps_per_epoch,
        validation_steps=val_steps_per_epoch,
        epochs=epochs,
        callbacks=[model_checkpoint, early_stopping, csv_logger],
    )
else:
    # Rebuild the architecture without pretrained weights, then load the saved ones
    weights_path = './impmodels/model_weights.h5'
    model = TransUNet(image_size=224, pretrain=False)
    model.load_weights(weights_path)
    model.compile(optimizer='adam', loss=dice_loss, metrics=['accuracy'])
|
|
229 |
|
|
|
230 |
|
|
|
231 |
|
|
|
232 |
|
|
|
233 |
|
|
|
234 |
""" Display some predictions """ |
|
|
235 |
|
|
|
236 |
preds = []
ground_truths = []
num_samples = 50

# Generate predictions and ground truths
for i in range(num_samples):
    # Fetch one sample from the test generator. In test mode it yields
    # (image, mask, file pattern), so a two-name unpack would raise ValueError;
    # only the first two entries are needed here.
    batch = next(test_generator)
    image, mask = batch[0], batch[1]

    preds.append(model.predict(image))  # Predict using the model
    ground_truths.append(mask)

best_threshold = 0.99

# Apply the best threshold to all predictions
final_preds = [(pred >= best_threshold).astype(int) for pred in preds]

# Compute Dice loss for each thresholded prediction against its ground truth
for i, (truth, final_pred) in enumerate(zip(ground_truths, final_preds)):
    loss = dice_loss(truth, final_pred)
    print(f"Image {i + 1}: Dice Loss = {loss:.4f}")
|
|
258 |
|
|
|
259 |
|
|
|
260 |
|
|
|
261 |
def visualize_predictions(generator, model, num_samples=8, target_size=(224, 224)):
    """
    Visualize predictions vs. ground truths overlaid on original images.

    Draws one row per sample with three panels: the image, the image with the
    ground truth painted red, and the image with the thresholded prediction
    painted green.

    Parameters:
        generator (generator): Data generator whose items start with
            (image batch, mask batch); extra trailing entries are ignored
        model (Model): Trained segmentation model
        num_samples (int): Number of samples to visualize
        target_size (tuple): Target size for resizing (default=(224, 224)).
            Not used by this function; kept so existing calls keep working.
    """
    # squeeze=False keeps `axes` two-dimensional even when num_samples == 1,
    # so the axes[i, col] indexing below always works
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    for i in range(num_samples):
        # Fetch one image and mask from the generator. Index instead of
        # unpacking: the test-mode generator yields a third element as well.
        batch = next(generator)
        image_batch, mask_batch = batch[0], batch[1]
        image = image_batch[0]  # Single image
        ground_truth = mask_batch[0]  # Corresponding ground truth mask

        # Ensure image is RGB
        if len(image.shape) == 2:
            image = np.stack([image] * 3, axis=-1)  # Convert grayscale to RGB

        # Reduce a three-channel ground truth to its first channel
        if ground_truth.ndim == 3 and ground_truth.shape[-1] == 3:
            ground_truth = ground_truth[:, :, 0]

        # Generate prediction
        raw_prediction = model.predict(image[np.newaxis])[0]  # Add batch dimension for prediction

        # Reduce a three-channel prediction to its first channel
        if raw_prediction.ndim == 3 and raw_prediction.shape[-1] == 3:
            prediction = raw_prediction[:, :, 0]
        else:
            prediction = raw_prediction
        prediction = (prediction >= 0.99).astype(np.uint8)  # Threshold prediction

        # Create overlays on copies so the plain image stays untouched
        gt_overlay = image.copy()
        pred_overlay = image.copy()

        # Overlay ground truth in red
        gt_overlay[ground_truth == 1] = [255, 0, 0]

        # Overlay prediction in green
        pred_overlay[prediction == 1] = [0, 255, 0]

        # Plot original image, ground truth overlay, and prediction overlay
        panels = (
            (image, f"Image {i + 1}"),
            (gt_overlay, f"Ground Truth Overlay {i + 1}"),
            (pred_overlay, f"Prediction Overlay {i + 1}"),
        )
        for col, (picture, title) in enumerate(panels):
            axes[i, col].imshow(picture)
            axes[i, col].set_title(title)
            axes[i, col].axis('off')

    plt.tight_layout()
    plt.show()
|
|
322 |
|
|
|
323 |
|
|
|
324 |
# Plot 24 samples drawn from the test generator with the model built or loaded above
visualize_predictions(generator=test_generator, model=model, num_samples=24)
|
|
326 |
|
|
|
327 |
|
|
|
328 |
|
|
|
329 |
|
|
|
330 |
|
|
|
331 |
def binary_to_rle(binary_mask):
    """
    Converts a binary mask to RLE (Run-Length Encoding).

    The mask is flattened in row-major (C) order, the same order
    rle_to_binary uses when it reshapes a decoded mask, so encoding a decoded
    mask gives back the original runs.

    Parameters:
        binary_mask (array-like): mask whose non-zero entries are foreground.

    Returns:
        str: space-separated ``start length`` pairs with 1-indexed starts, or
        an empty string when the mask has no foreground.
    """
    flat_mask = np.asarray(binary_mask).ravel(order='C')

    # Pad with background on both sides so every run has a visible start and
    # end, including runs that touch the first or last pixel.
    padded = np.concatenate(([0], (flat_mask != 0).astype(np.int8), [0]))

    # Positions where the value flips: even entries open a run (0-based start),
    # odd entries close it (exclusive end).
    flips = np.flatnonzero(padded[1:] != padded[:-1])
    starts, ends = flips[0::2], flips[1::2]

    pairs = np.column_stack((starts + 1, ends - starts)).ravel()
    return ' '.join(map(str, pairs.tolist()))
|
|
350 |
|
|
|
351 |
|
|
|
352 |
|
|
|
353 |
def save_predictions_to_csv(test_generator, model, output_csv_path, max_batches=None):
    """
    Generates predictions using the trained model and writes them to a CSV file in RLE format.

    Parameters:
        test_generator: Iterable yielding (images, masks, ids). ``ids`` may be
            a sequence with one entry per image or a single string that
            identifies a one-image batch.
        model: The trained segmentation model.
        output_csv_path: Path to save the CSV file.
        max_batches (int | None): Stop after this many batches. None consumes
            the iterable until it is exhausted, which never happens for an
            endless generator such as custom_generator.
    """
    # Function-scope import: itertools is not among the file's top-level imports
    from itertools import islice

    with open(output_csv_path, mode='w', newline='') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(['id', 'segmentation'])  # Header row

        for image, masks, ids in islice(test_generator, max_batches):
            predictions = model.predict(image)
            predictions = (predictions > 0.99).astype(int)

            # A lone string would be zipped character by character, pairing
            # the mask with the first character only; wrap it so the whole
            # identifier is written.
            if isinstance(ids, str):
                ids = [ids]

            # NOTE(review): a multi-channel prediction is encoded as one RLE
            # string covering all channels - confirm the intended row layout.
            for pred_mask, mask_id in zip(predictions, ids):
                rle = binary_to_rle(pred_mask.squeeze())
                csv_writer.writerow([mask_id, rle])

            print(f"Processed {len(ids)} predictions...")


# The test generator is endless and yields one sample per step, so bound the
# export by the size of the test split.
# NOTE: it samples ids at random, so a sample may appear more than once.
save_predictions_to_csv(test_generator, model, 'model_output.csv', max_batches=len(test_ids))