[e50482]: / train_2d.py

Download this file

546 lines (494 with data), 24.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
import glob
import shutil
import argparse
import monai
import torch
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from torchmetrics.functional import iou, dice_score
from torch import nn
from pytorch_lightning import Trainer
import pytorch_lightning as pl
import numpy as np
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import datetime
import monai.transforms as mt
from torchmetrics import IoU, F1
from attn_unet_2d import AttU_Net2D
from unet_2d import Unet_2d
from data_2d import train_loader_ACDC, val_loader_ACDC
from monai.losses.dice import DiceLoss, DiceCELoss
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator as ea
import os
import matplotlib.pyplot as plt
from pathlib import Path
# Manual seeding
torch.manual_seed(42)
"""-----------------------Arguments-----------------------"""
parser = argparse.ArgumentParser(description='Training of UNet2D Segmentation')
parser.add_argument("--model_choice", default="UNet2D_Attention", type=str)
parser.add_argument("--kfolds", default=5, type=int)
parser.add_argument("--Loss_choice", default="dice", type=str)
parser.add_argument("--Batch_size_train", default=10, type=int)
parser.add_argument("--Batch_size_val", default=1, type=int)
parser.add_argument("--lr", default=np.float32(0.0005), type=float)
parser.add_argument("--lr_decay", default=np.float32(0.985), type=float)
parser.add_argument("--gpus", default=1, type=int)
parser.add_argument("--maximum_epochs", default=400, type=int)
parser.add_argument("--patience_early_stop", default=400, type=int)
parser.add_argument('--optimizer_choice', default='adam', type=str)
parser.add_argument('--scheduler_choice', default='plateau', type=str)
parser.add_argument('--dropout_rate', default=0.3, type=float)
"""--------------Models, Hyperparameters, Metrics and Variables--------------"""
arguments = parser.parse_args()
model_choice = arguments.model_choice
k_folds = arguments.kfolds
loss_choice = arguments.Loss_choice
batch_size_train = arguments.Batch_size_train
batch_size_val = arguments.Batch_size_val
learning_rate = arguments.lr
LR_decay_rate = arguments.lr_decay
dev = arguments.gpus
max_epochs = arguments.maximum_epochs
patience = arguments.patience_early_stop
optim_choice = arguments.optimizer_choice
scheduler_choice = arguments.scheduler_choice
drop_rate = arguments.dropout_rate
print("Model Choice:", model_choice,
"Dropout Rate", drop_rate,
"K Folds:", k_folds,
"Loss Choice:", loss_choice,
"LR Decay Rate:", LR_decay_rate,
"Device:", dev,
"Patience Early Stopping:", patience,
"Optimizer:", optim_choice,
"Scheduler:", scheduler_choice)
# Model
if model_choice == "UNet2D":
my_model = Unet_2d(drop=drop_rate).cuda() # with upsample --> Has more parameters
elif model_choice == "UNet2D_Attention":
my_model = AttU_Net2D(drop=drop_rate).cuda() # with Attention --> Has the most parameters
else:
raise ValueError("Wrong model choice!")
# Loss Function
if loss_choice == "dice_ce":
loss_func = DiceCELoss(include_background=True,
to_onehot_y=True,
sigmoid=False,
softmax=True,
jaccard=False,
reduction="mean",
smooth_nr=1e-05,
smooth_dr=1e-05,
# ce_weight=class_weights,
batch=False).cuda()
elif loss_choice == "dice":
loss_func = DiceLoss(include_background=True,
to_onehot_y=True,
sigmoid=False,
softmax=True,
jaccard=False,
reduction="mean",
smooth_nr=1e-05,
smooth_dr=1e-05,
batch=False).cuda()
else:
raise ValueError("Wrong loss choice!")
# IoU
IOU_metric = IoU(num_classes=4, absent_score=-1., reduction="none").cuda()
# F1 score
f1_metric = F1(num_classes=4, mdmc_average="samplewise", average='none').cuda()
# Softmax
soft = torch.nn.Softmax(dim=1)
# Required dimensions
tar_shape = [300, 300]
crop_shape = [224, 224]
"""---------Post Processing---------"""
keep_largest = monai.transforms.KeepLargestConnectedComponent(applied_labels=[0, 1, 2, 3])
"""---------Augmentations---------"""
train_transform = mt.Compose(
[mt.ResizeWithPadOrCropD(keys=["image", "mask"], spatial_size=tar_shape, mode="constant"),
mt.RandSpatialCropD(keys=["image", "mask"], roi_size=crop_shape, random_center=True, random_size=False),
mt.Rand2DElasticD(
keys=["image", "mask"],
prob=0.25,
spacing=(50, 50),
magnitude_range=(1, 3),
rotate_range=(np.pi / 4,),
scale_range=(0.1, 0.1),
translate_range=(10, 10),
padding_mode="border",
),
# mt.RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
mt.RandFlipd(["image", "mask"], spatial_axis=[0], prob=0.5),
mt.RandFlipd(["image", "mask"], spatial_axis=[1], prob=0.5),
mt.RandRotateD(keys=["image", "mask"], range_x=np.pi / 4, range_y=np.pi / 4, range_z=0.0, prob=0.50,
keep_size=True, mode=("nearest", "nearest"), align_corners=False),
mt.RandRotate90D(keys=["image", "mask"], prob=0.25, spatial_axes=(0, 1)),
mt.RandGaussianNoiseD(keys=["image"], prob=0.15, std=0.01),
mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False),
mt.RandZoomd(
keys=["image", "mask"],
min_zoom=0.9,
max_zoom=1.2,
mode="nearest",
align_corners=None,
prob=0.25,
),
mt.RandKSpaceSpikeNoiseD(keys=["image"], prob=0.15, intensity_range=(5.0, 7.5)),
]
)
val_transform = mt.Compose([
mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False)
])
"""------------------Datasets and Directories to save the results------------------"""
# Define the K-fold Cross Validator
splits = KFold(n_splits=k_folds, shuffle=True, random_state=42)
# train + val dataset for 5 fold cross validation training
concatenated_dataset = train_loader_ACDC(transform=None, train_index=None)
# paths to store the checkpoints
if not os.path.exists(r"../unet/checkpoints"):
os.makedirs(r"../unet/checkpoints")
checkpoint_path = "../unet/checkpoints"
if not os.path.exists(r"../unet/tb_logs"):
os.makedirs(r"../unet/tb_logs")
tb_path = "../unet/tb_logs"
if not os.path.exists(r"../unet/csv_logs"):
os.makedirs(r"../unet/csv_logs")
csv_path = "../unet/csv_logs"
# Temporarily store the validated image and ground truth plots --> to be moved to the respective folders
if not os.path.exists(r'../unet/val_images_temp_2d/'):
os.makedirs(r'../unet/val_images_temp_2d/')
val_path = r'../unet/val_images_temp_2d/'
# Save the validation images and ground truths
if not os.path.exists(r'../unet/val_images_save_2d/'):
os.makedirs(r'../unet/val_images_save_2d/')
image_path = r'../unet/val_images_save_2d/'
# pad the images so that they are divisible by 16
def Pad_images(image):
orig_shape = list(image.size())
original_x = orig_shape[2]
original_y = orig_shape[3]
new_x = (16 - (original_x % 16)) + original_x
new_y = (16 - (original_y % 16)) + original_y
new_shape = [new_x, new_y]
b, c, h, w = image.shape
m = image.min()
x_max = new_shape[0]
y_max = new_shape[1]
result = torch.Tensor(b, c, x_max, y_max).fill_(m)
xx = (x_max - h) // 2
yy = (y_max - w) // 2
result[:, :, xx:xx + h, yy:yy + w] = image
return result, tuple([xx, yy]) # result is a torch tensor in CPU --> have to move to GPU
# pass the padded image, the indices and the original shape
def UnPad_imges(image, indices, org_shape):
b, c, h, w = org_shape
xx = indices[0]
yy = indices[1]
return image[:, :, xx:xx + h, yy:yy + w] # image is a torch tensor --> have to move to GPU
# reset the parameters of the model
def reset_weights(m):
if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d) or \
isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.ConvTranspose3d):
m.reset_parameters()
# save the masks
def save_plots_mask(target, idx):
out_path = os.path.join(val_path, f"{idx}_gt" + "." + 'png')
out_save_path = os.path.join(image_path, f"{idx}_gt" + "." + 'png')
target = target.squeeze()
target = np.array(target.cpu())
plt.imsave(out_save_path, target)
image_file_name = str(idx) + "_gt"
plt.title = image_file_name
plt.imsave(out_path, target, format='png')
plt.close()
# save the predictions
def save_plots_pred(pred, idx):
out_path = os.path.join(val_path, f"{idx}_pred" + "." + 'png')
out_save_path = os.path.join(image_path, f"{idx}_pred" + "." + 'png')
soft_pred_log = soft(pred)
final_pred_log = torch.argmax(soft_pred_log, dim=1)
# Post Processing after softmax and argmax
final_pred_log = keep_largest(final_pred_log)
final_pred_log = np.array(final_pred_log.cpu().squeeze())
plt.imsave(out_save_path, final_pred_log)
image_file_name = str(idx) + "_pred"
plt.title = image_file_name
plt.imsave(out_path, final_pred_log, format='png')
plt.close()
# save the images
def save_plots_image(img, idx):
out_path = os.path.join(val_path, f"{idx}_image" + "." + 'png')
out_save_path = os.path.join(image_path, f"{idx}_image" + "." + 'png')
final_image = np.array(img.cpu().squeeze())
plt.imsave(out_save_path, final_image)
image_file_name = str(idx) + "_image"
plt.title = image_file_name
plt.imsave(out_path, final_image, format='png')
plt.close()
class Train2D(pl.LightningModule):
def __init__(self):
super(Train2D, self).__init__()
self.net = my_model
self.loss_function = loss_func
def forward(self, x):
return self.net(x) # returns output of the model --> B Classes H W
def training_step(self, batch, batch_idx):
img, mask = batch["image"], batch["mask"] # image --> torch.float(), mask --> torch.Long
img = img.float() # B Channels H W
mask = mask.long() # B Channels H W
# image passed through the model
out = self(img) # B Classes H W
# calculate loss
loss = self.loss_function(out, mask)
# calculate softmax of the prediction
soft_out = soft(out) # softmax of the prediction
mask = mask.squeeze(dim=1) # B H W
""" Calculation of metrics using Torchmetrics"""
# # calculate iou
# iou_all = IOU_metric(soft_out, mask)
# # train_iou = iou_all.mean()
# iou_all = iou_all[iou_all != -1.]
# if len(iou_all) == 0:
# train_iou = torch.tensor(0.0).cuda()
# else:
# train_iou = iou_all.mean()
# # calculate dice score
# dice_all = f1_metric(soft_out, mask)
# # train_dice = dice_all.mean()
# dice_all_np = dice_all.cpu().numpy()
# dice_all_np = dice_all_np[~np.isnan(dice_all_np)]
# dice_all = (torch.from_numpy(dice_all_np).cuda())
# if len(dice_all) == 0:
# train_dice = torch.tensor(0.0).cuda()
# else:
# train_dice = dice_all.mean()
""" Calculation of metrics using Torchmetrics functional"""
# iou
iou_all = iou(soft_out, mask, absent_score=-1., num_classes=4, reduction='none', ignore_index=None)
iou_all = iou_all[iou_all != -1.]
if len(iou_all) == 0:
train_iou = torch.tensor(0.0).cuda()
else:
train_iou = iou_all.mean()
# dice score
dice_all = dice_score(soft_out, mask, bg=True, no_fg_score=-1., reduction='none')
dice_all = dice_all[dice_all != -1.]
if len(dice_all) == 0:
train_dice = torch.tensor(0.0).cuda()
else:
train_dice = dice_all.mean()
# logger --> log the train_loss
self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=False)
return {"loss": loss, "train_iou": train_iou, "train_dice": train_dice}
def validation_step(self, batch, batch_idx):
img, mask = batch["image"], batch["mask"] # image --> torch.float(), mask --> torch.Long
img = img.float() # B Channels H W
mask = mask.long() # B Channels H W
###############################################
save_plots_image(img, batch_idx) # save the images
save_plots_mask(mask, batch_idx) # save the masks
###############################################
# pad the image
padded_image, ind = Pad_images(img)
padded_image = padded_image.cuda()
# padded image passed through the model
out = self(padded_image).cuda() # B Classes H W
# unpad the image
unpadded_prediction = UnPad_imges(out, ind, img.shape)
unpadded_prediction = unpadded_prediction.cuda()
###############################################
save_plots_pred(unpadded_prediction, batch_idx) # save the predictions
###############################################
# calculate loss
loss = self.loss_function(unpadded_prediction, mask)
# calculate softmax of the prediction
soft_out = soft(unpadded_prediction)
mask = mask.squeeze(dim=1) # B H W
""" Calculation of metrics using Torchmetrics"""
# # calculate iou
# iou_all = IOU_metric(soft_out, mask)
# # val_iou = iou_all.mean()
# iou_all = iou_all[iou_all != -1.]
# if len(iou_all) == 0:
# val_iou = torch.tensor(0.0).cuda()
# else:
# val_iou = iou_all.mean()
# # calculate dice score
# dice_all = f1_metric(soft_out, mask)
# # val_dice = dice_all.mean()
# dice_all_np = dice_all.cpu().numpy()
# dice_all_np = dice_all_np[~np.isnan(dice_all_np)]
# dice_all = (torch.from_numpy(dice_all_np).cuda())
# if len(dice_all) == 0:
# val_dice = torch.tensor(0.0).cuda()
# else:
# val_dice = dice_all.mean()
""" Calculation of metrics using Torchmetrics functional"""
# iou
iou_all = iou(soft_out, mask, absent_score=-1., num_classes=4, reduction='none', ignore_index=None)
iou_all = iou_all[iou_all != -1.]
if len(iou_all) == 0:
val_iou = torch.tensor(0.0).cuda()
else:
val_iou = iou_all.mean()
# dice score
dice_all = dice_score(soft_out, mask, bg=True, no_fg_score=-1., reduction='none')
dice_all = dice_all[dice_all != -1.]
if len(dice_all) == 0:
val_dice = torch.tensor(0.0).cuda()
else:
val_dice = dice_all.mean()
# logger --> log the val_loss
self.log('val_loss', loss, on_step=True, on_epoch=False, prog_bar=False)
return {'loss': loss, 'val_iou': val_iou, 'val_dice': val_dice}
def training_epoch_end(self, train_step_outputs):
"""-----Calculate and logs the average train loss, IoU score and Dice Score-----"""
avg_train_loss = torch.stack([x['loss'] for x in train_step_outputs]).mean()
avg_train_iou = torch.stack([x['train_iou'] for x in train_step_outputs]).mean()
avg_train_dice = torch.stack([x['train_dice'] for x in train_step_outputs]).mean()
self.log('avg_train_loss', avg_train_loss, on_step=False, on_epoch=True, prog_bar=True)
self.log('avg_train_iou', avg_train_iou, on_step=False, on_epoch=True, prog_bar=True)
self.log('avg_train_dice', avg_train_dice, on_step=False, on_epoch=True, prog_bar=True)
def validation_epoch_end(self, val_step_outputs):
"""-----Calculate and logs the average validation loss, IoU score and Dice Score-----"""
avg_val_loss = torch.stack([x['loss'] for x in val_step_outputs]).mean()
avg_val_iou = torch.stack([x['val_iou'] for x in val_step_outputs]).mean()
avg_val_dice = torch.stack([x['val_dice'] for x in val_step_outputs]).mean()
self.log('avg_val_loss', avg_val_loss, on_step=False, on_epoch=True, prog_bar=True)
self.log('avg_val_iou', avg_val_iou, on_step=False, on_epoch=True, prog_bar=True)
self.log('avg_val_dice', avg_val_dice, on_step=False, on_epoch=True, prog_bar=True)
return {'avg_val_loss': avg_val_loss, 'avg_val_iou': avg_val_iou, 'avg_val_dice': avg_val_dice}
def configure_optimizers(self):
"""-----Optimizers and LR Schedulers-----"""
if optim_choice == 'adam':
optim = torch.optim.Adam(self.net.parameters(), lr=learning_rate, eps=1e-8, weight_decay=1e-5, amsgrad=True)
else:
raise ValueError("Wrong optimizer!")
if scheduler_choice == 'plateau':
scheduler = ReduceLROnPlateau(optim, mode='min', factor=LR_decay_rate, patience=20)
elif scheduler_choice == 'step':
scheduler = StepLR(optim, step_size=10, gamma=LR_decay_rate, last_epoch=-1)
else:
raise ValueError("Wrong scheduler!")
return {
"optimizer": optim,
"lr_scheduler": scheduler,
'monitor': 'avg_train_loss'
}
def run_training():
"""--------------------------------------5 fold Cross Validation--------------------------------------"""
for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(concatenated_dataset)))):
print(len(train_idx), len(val_idx))
print("--------------------------", "Fold", fold + 1, "--------------------------")
print("Train Batch Size:", batch_size_train,
"Val Batch Size:", batch_size_val,
"Learning Rate:", learning_rate,
"Max epochs:", max_epochs)
"""-------------------Train the model for "max_epochs" for each fold-------------------"""
# training dataset
training_data = DataLoader(train_loader_ACDC(transform=train_transform, train_index=train_idx),
batch_size=batch_size_train,
shuffle=True, num_workers=2)
# validation dataset
validation_data = DataLoader(val_loader_ACDC(transform=val_transform, val_index=val_idx),
batch_size=batch_size_val, shuffle=False, num_workers=2)
# init the model
model = Train2D()
# name of the model
name = str(model_choice) + "_" + str(drop_rate) + "_" + str(datetime.date.today()) + "_Fold_" + str(fold + 1)
# Checkpoint callback and Early Stopping
checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=checkpoint_path,
save_top_k=1,
save_last=True,
verbose=True,
monitor='avg_val_iou',
mode='max',
filename=name + "_" + '{epoch}-{avg_val_iou:.4f}',
)
early_stop_callback = pl.callbacks.EarlyStopping(monitor='avg_val_loss',
min_delta=0.00,
patience=patience,
verbose=False,
mode='min')
# Tensorboard logger --> tensorboard --logdir=tb_logs
tensorboard_logger = TensorBoardLogger(tb_path, name=name)
# CSV logger
csv_logger = CSVLogger(csv_path, name=name)
# Trainer for training
trainer = Trainer(max_epochs=max_epochs, callbacks=[early_stop_callback, checkpoint_callback],
gpus=1, logger=[tensorboard_logger, csv_logger], fast_dev_run=False, log_every_n_steps=2)
# Training the model
trainer.fit(model, train_dataloader=training_data, val_dataloaders=validation_data)
# Save the best model
best_model_path = checkpoint_callback.best_model_path
model = model.load_from_checkpoint(best_model_path)
model.eval().cuda()
fname = str(model_choice) + "_Best_" + str(drop_rate) + "_Fold_" + str(fold + 1)
if not os.path.exists(r'../unet/best_models/'):
os.makedirs(r'../unet/best_models/')
torch.save(model, str(Path('../unet/best_models/', fname + '.pt')))
# Folders to save the validation images for each fold
if not os.path.exists(os.path.join(r'../unet/', name, f"{fold + 1}_Fold")):
os.makedirs(os.path.join(r'../unet/', name, f"{fold + 1}_Fold"))
val_images_path = os.path.join(r'../unet/', name, f"{fold + 1}_Fold")
# Move the validated images to the respective folders
for filename in glob.glob(os.path.join(val_path, '*.*')):
shutil.move(filename, val_images_path)
# Save plots --> Loss, IoU and Dice
plot_out_path = str(Path(r"../unet/Plots/", name))
if not os.path.exists(plot_out_path):
os.makedirs(plot_out_path)
event_acc = ea(str(Path(r"../unet/tb_logs/", name, "version_0")))
event_acc.Reload()
_, _, training_loss = zip(*event_acc.Scalars('avg_train_loss'))
_, _, validation_loss = zip(*event_acc.Scalars('avg_val_loss'))
_, _, training_iou = zip(*event_acc.Scalars('avg_train_iou'))
_, _, validation_iou = zip(*event_acc.Scalars('avg_val_iou'))
_, _, training_dice = zip(*event_acc.Scalars('avg_train_dice'))
_, _, validation_dice = zip(*event_acc.Scalars('avg_val_dice'))
t_loss, v_loss, t_iou, v_iou, t_dice, v_dice = np.array(training_loss), np.array(validation_loss), \
np.array(training_iou), np.array(validation_iou), \
np.array(training_dice), np.array(validation_dice)
min_length = min(len(t_loss), len(v_loss), len(t_iou), len(v_iou), len(t_dice), len(v_dice))
total_epochs = np.arange(1, min_length + 1)
# Save the Loss, IoU and Dice plots
plt.figure(1)
plt.rcParams.update({'font.size': 15})
plt.plot(total_epochs, t_loss[0:min_length], 'X-', label='Training Loss', linewidth=2.0)
plt.plot(total_epochs, v_loss[0:min_length], 'o-', label='Validation Loss', linewidth=2.0)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.minorticks_on()
plt.grid(which='minor', linestyle='--')
plt.savefig(str(Path(plot_out_path, 'Loss_Plot.png')), bbox_inches='tight', format='png', dpi=300)
plt.close() # Always plt.close() to save memory
plt.figure(2)
plt.rcParams.update({'font.size': 15})
plt.plot(total_epochs, t_iou[0:min_length], 'X-', label='Training IOU', linewidth=2.0)
plt.plot(total_epochs, v_iou[0:min_length], 'o-', label='Validation IOU', linewidth=2.0)
plt.xlabel('Epoch')
plt.ylabel('IOU')
plt.legend(loc='lower right')
plt.minorticks_on()
plt.grid(which='minor', linestyle='--')
plt.savefig(str(Path(plot_out_path, 'Iou_Plot.png')), bbox_inches='tight', format='png', dpi=300)
plt.close() # Always plt.close() to save memory
plt.figure(3)
plt.rcParams.update({'font.size': 15})
plt.plot(total_epochs, t_dice[0:min_length], 'X-', label='Training Dice', linewidth=2.0)
plt.plot(total_epochs, v_dice[0:min_length], 'o-', label='Validation Dice', linewidth=2.0)
plt.xlabel('Epoch')
plt.ylabel('Dice Score')
plt.legend(loc='lower right')
plt.minorticks_on()
plt.grid(which='minor', linestyle='--')
plt.savefig(str(Path(plot_out_path, 'Dice_Plot.png')), bbox_inches='tight', format='png', dpi=300)
plt.close() # Always plt.close() to save memory
# reset model parameters after each fold
model.apply(reset_weights)
if __name__ == "__main__":
run_training()