Diff of /trainer.py [000000] .. [72db80]

--- a
+++ b/trainer.py
@@ -0,0 +1,216 @@
+import copy
+import csv
+import os
+import pathlib
+import time
+from datetime import datetime
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+
+def load_checkpoint(bpath):
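+    """Load the latest checkpoint (and the best-weights checkpoint, if one
+    exists) from <bpath>/checkpoint.
+
+    Returns (epoch, state_dict, best_weights, optimizer_state, best_loss,
+    best_pred), or six Nones when no checkpoint file exists."""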
+
+    checkpoint_folder = os.path.join(bpath, 'checkpoint')
+    checkpoint_filename = os.path.join(
+        checkpoint_folder, 'checkpoint.pth.tar')
+
+    bestweights_filename = os.path.join(
+        checkpoint_folder, 'best_weights_checkpoint.pth.tar')
+
+    file = pathlib.Path(checkpoint_filename)
+
+    if not file.exists():
+        return None, None, None, None, None, None
+
+    file = pathlib.Path(bestweights_filename)
+
+    best_weight = None
+    if file.exists():
+        # map_location='cpu' keeps loading robust on machines without a GPU;
+        # states are moved to the right device when they are applied.
+        best_weight = torch.load(bestweights_filename, map_location='cpu')
+        best_weight = best_weight['state_dict']
+
+    checkpoint = torch.load(checkpoint_filename, map_location='cpu')
+
+    return (checkpoint['epoch'], checkpoint['state_dict'], best_weight,
+            checkpoint['optimizer'], checkpoint['best_loss'],
+            checkpoint['best_pred'])
+
+
+def save_checkpoint(bpath, state, is_best=False):
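+    """Write the training state to <bpath>/checkpoint/checkpoint.pth.tar.
+    When is_best is set, also snapshot it as best_weights_checkpoint.pth.tar
+    and mirror the best score/loss to plain-text files for quick inspection."""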
+
+    checkpoint_folder = os.path.join(bpath, 'checkpoint')
+    os.makedirs(checkpoint_folder, exist_ok=True)
+
+    if is_best:
+        with open(os.path.join(checkpoint_folder, 'best_pred.txt'), 'w') as f:
+            f.write(str(state['best_pred']))
+
+        with open(os.path.join(checkpoint_folder, 'best_loss.txt'), 'w') as f:
+            f.write(str(state['best_loss']))
+
+        torch.save(state, os.path.join(checkpoint_folder,
+                                       'best_weights_checkpoint.pth.tar'))
+
+    torch.save(state, os.path.join(checkpoint_folder,
+                                   'checkpoint.pth.tar'))
+
+
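+# A minimal sketch of a metric compatible with the `metrics` dict that
+# train_model expects: it receives thresholded prediction and ground-truth
+# arrays and returns a Dice score, or None when both masks are empty. The
+# name `example_dice` is illustrative and not part of the training code.
+def example_dice(y_pred, y_true):
+    y_pred = np.asarray(y_pred, dtype=bool)
+    y_true = np.asarray(y_true, dtype=bool)
+    denom = y_pred.sum() + y_true.sum()
+    if denom == 0:
+        return None  # undefined for two empty masks; the training loop skips None
+    return 2.0 * np.logical_and(y_pred, y_true).sum() / denom
+
+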
+def train_model(model, criterion, dataloaders, optimizer, scheduler, metrics, bpath, num_epochs=3):
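+    """Train and validate `model`, checkpointing every epoch and logging
+    per-epoch losses and metrics to <bpath>/log.csv.
+
+    `metrics` is expected to map a metric name (e.g. 'dice', 'dice_target')
+    to a callable taking boolean (y_pred, y_true) arrays and returning a
+    float score, or None when the score is undefined for a batch."""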
+
+    start_epoch, state_dict, bweights, optm, bloss, bpred = load_checkpoint(
+        bpath)
+
+    resumed = start_epoch is not None
+    if resumed:
+        print('\nCheckpoint found, resuming from epoch {}\n'.format(start_epoch))
+        model.load_state_dict(state_dict)
+        # Restore the optimizer state as well, so momentum and LR bookkeeping
+        # survive the restart.
+        optimizer.load_state_dict(optm)
+        start_epoch += 1
+
+        # Fall back to the current weights if no best-weights file was found
+        best_model_wts = copy.deepcopy(
+            bweights if bweights is not None else model.state_dict())
+        best_loss = float(bloss)
+
+        best_Train_dice = 1e-5
+        best_Valid_dice = bpred
+    else:
+        start_epoch = 1
+        best_model_wts = copy.deepcopy(model.state_dict())
+        best_loss = 1e10
+
+        best_Train_dice = 1e-5
+        best_Valid_dice = 1e-5
+
+    since = time.time()
+
+    # Use gpu if available
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+    # Initialize the log file for training and testing loss and metrics
+    fieldnames = ['epoch', 'Train_loss', 'Valid_loss'] + \
+        [f'Train_{m}' for m in metrics.keys()] + \
+        [f'Valid_{m}' for m in metrics.keys()]
+    log_path = os.path.join(bpath, 'log.csv')
+    # Write a fresh header only when starting from scratch; on resume the
+    # existing rows are kept and new epochs are appended below them.
+    if not resumed:
+        with open(log_path, 'w', newline='') as csvfile:
+            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+            writer.writeheader()
+
+    for epoch in range(start_epoch, num_epochs+1):
+        print('Epoch {}/{}'.format(epoch, num_epochs))
+        print('-' * 10)
+        # Each epoch has a training and validation phase
+        # Initialize batch summary
+        # Collect per-batch metric values; the loss fields are overwritten
+        # with scalars at the end of each phase.
+        batchsummary = {a: [] for a in fieldnames}
+
+        for phase in ['Train', 'Valid']:
+            if phase == 'Train':
+                model.train()  # Set model to training mode
+            else:
+                model.eval()   # Set model to evaluate mode
+
+            # Iterate over data, accumulating a running loss for the epoch.
+            running_loss = 0.0
+            num_batches = 0
+
+            for sample in tqdm(dataloaders[phase]):
+
+                inputs = sample['image'].to(device)
+                masks = sample['mask'].to(device)
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # track history only in the training phase
+                with torch.set_grad_enabled(phase == 'Train'):
+                    outputs = model(inputs)
+                    # (for torchvision segmentation models use outputs['out'])
+                    loss = criterion(outputs, masks)
+                    running_loss += loss.item()
+                    num_batches += 1
+
+                    y_pred = outputs.detach().cpu().numpy().squeeze(1)
+                    y_true = masks.cpu().numpy().squeeze(1)
+
+                    for name, metric in metrics.items():
+                        if name in ('dice', 'dice_target'):
+                            # Binarise with a classification threshold of 0.5
+                            val_metric = metric(y_pred > 0.5, y_true > 0)
+
+                            if val_metric is not None:
+                                batchsummary[f'{phase}_{name}'].append(
+                                    val_metric)
+
+                    # backward + optimize only if in training phase
+                    if phase == 'Train':
+                        loss.backward()
+                        optimizer.step()
+
+            batchsummary['epoch'] = epoch
+            # Report the mean loss over the whole epoch, not just the last batch
+            epoch_loss = running_loss / max(num_batches, 1)
+            batchsummary[f'{phase}_loss'] = epoch_loss
+            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
+
+        scheduler.step()
+        print('New LR: ', scheduler.get_last_lr())
+
+        for field in fieldnames[3:]:
+            values = batchsummary[field]
+            batchsummary[field] = np.mean(values) if values else 0.0
+
+        print(batchsummary)
+
+        # The key must match a metric name handled in the loop above
+        epoch_valid_dice = batchsummary['Valid_dice_target']
+        is_best = False
+        with open(log_path, 'a', newline='') as csvfile:
+            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+            writer.writerow(batchsummary)
+
+            SAVE_BESTLOSS_WEIGHT = False
+            if SAVE_BESTLOSS_WEIGHT:
+                # deep copy the model when the validation loss improves
+                epoch_valid_loss = batchsummary['Valid_loss']
+                if epoch_valid_loss < best_loss:
+                    print('\nnew best loss: {:.4f} in epoch {}\n'.format(
+                        epoch_valid_loss, epoch))
+                    best_loss = epoch_valid_loss
+                    best_model_wts = copy.deepcopy(model.state_dict())
+                    str_datetime = datetime.now().strftime("%Y%m%d_%H_%M_%S")
+
+                    best_Train_dice = batchsummary['Train_dice']
+                    best_Valid_dice = batchsummary['Valid_dice']
+
+                    torch.save(model, os.path.join(
+                        bpath, 'weights_partial_epch{}_{}.pt'.format(epoch, str_datetime)))
+            else:
+                # deep copy the model when the validation dice improves
+                if epoch_valid_dice > best_Valid_dice:
+                    is_best = True
+                    print('\nNew best valid dice: {:.4f} in epoch {}\n'.format(
+                        epoch_valid_dice, epoch))
+                    best_loss = batchsummary['Valid_loss']
+                    best_model_wts = copy.deepcopy(model.state_dict())
+                    str_datetime = datetime.now().strftime("%Y%m%d_%H_%M_%S")
+
+                    best_Train_dice = batchsummary['Train_dice']
+                    best_Valid_dice = epoch_valid_dice
+
+                    torch.save(model, os.path.join(
+                        bpath, 'weights_partial_diceval_epch{}_{}.pt'.format(epoch, str_datetime)))
+
+            save_checkpoint(bpath, {
+                'epoch': epoch,
+                'state_dict': model.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'best_pred': best_Valid_dice,
+                'best_loss': best_loss
+            }, is_best=is_best)
+
+    time_elapsed = time.time() - since
+    print('Training complete in {:.0f}m {:.0f}s'.format(
+        time_elapsed // 60, time_elapsed % 60))
+    print('Valid loss at best dice: {:.4f}'.format(best_loss))
+    print('Best valid dice: {:.4f}'.format(best_Valid_dice))
+
+    # load best model weights
+    model.load_state_dict(best_model_wts)
+
+    return best_Train_dice, best_Valid_dice
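+
+
+if __name__ == '__main__':
+    # Minimal usage sketch, assuming a single-channel segmentation model whose
+    # forward pass ends in a sigmoid (so outputs are probabilities and the 0.5
+    # threshold above applies), and dataloaders yielding {'image', 'mask'}
+    # batches. `MySegNet` and `make_loader` are hypothetical placeholders.
+    from my_project import MySegNet, make_loader  # hypothetical imports
+
+    os.makedirs('./runs', exist_ok=True)
+
+    model = MySegNet()
+    criterion = torch.nn.BCELoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
+
+    dataloaders = {'Train': make_loader('train'),
+                   'Valid': make_loader('valid')}
+    metrics = {'dice': example_dice, 'dice_target': example_dice}
+
+    train_model(model, criterion, dataloaders, optimizer, scheduler,
+                metrics, bpath='./runs', num_epochs=30)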