Diff of /train_3d.py [000000] .. [94d9b6]

Switch to side-by-side view

--- a
+++ b/train_3d.py
@@ -0,0 +1,91 @@
+import time
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import torch.optim as optim
+import torchvision
+from torchvision import datasets, models
+from torchvision import transforms as T
+from torch.utils.data import DataLoader, Dataset
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import time
+import pandas as pd
+from skimage import io, transform
+import matplotlib.image as mpimg
+from PIL import Image
+from sklearn.metrics import roc_auc_score
+import torch.nn.functional as F
+import scipy
+import random
+import pickle
+import scipy.io as sio
+import itertools
+from scipy.ndimage.interpolation import shift
+
+import warnings
+warnings.filterwarnings("ignore")
+plt.ion()
+
+from utils import *
+
+def train_model_3d(model, optimizer,dataloader, data_sizes, batch_size, num_epochs = 100, verbose = False):
+    since = time.time()
+    best_loss = np.inf
+    loss_hist = {'train':[],'validate':[]}
+    dice_score_0_hist = {'train':[],'validate':[]}
+    dice_score_1_hist = {'train':[],'validate':[]}
+    dice_score_2_hist = {'train':[],'validate':[]}
+    for i in range(num_epochs):
+        for phase in ['train', 'validate']:
+            running_loss = 0
+            running_dice_score_class_0 = 0
+            running_dice_score_class_1 = 0
+            running_dice_score_class_2 = 0
+
+            if phase == 'train':
+                model.train(True)
+            else:
+                model.train(False)
+
+            for data in dataloader[phase]:
+                optimizer.zero_grad()
+                input, target,_ = data
+                input = Variable(input).cuda()
+                true = Variable(torch.transpose(target,2,1).contiguous().view(-1,4,256,256)).cuda()
+                output = model(input)
+                output_reshaped = torch.transpose(output,2,1).contiguous().view(-1,4,256,256)
+                loss = dice_loss(true,output_reshaped,p=2.5)
+                if phase == 'train':
+                    loss.backward()
+                    optimizer.step()
+                running_loss += loss.data[0] * batch_size
+                dice_score_batch = dice_score(true,output_reshaped)
+                running_dice_score_class_0 += dice_score_batch[0] * batch_size
+                running_dice_score_class_1 += dice_score_batch[1] * batch_size
+                running_dice_score_class_2 += dice_score_batch[2] * batch_size
+            epoch_loss = running_loss/data_sizes[phase]
+            loss_hist[phase].append(epoch_loss) 
+            epoch_dice_score_0 = running_dice_score_class_0/data_sizes[phase]
+            epoch_dice_score_1 = running_dice_score_class_1/data_sizes[phase]
+            epoch_dice_score_2 = running_dice_score_class_2/data_sizes[phase]
+            dice_score_0_hist[phase].append(epoch_dice_score_0)
+            dice_score_1_hist[phase].append(epoch_dice_score_1)
+            dice_score_2_hist[phase].append(epoch_dice_score_2)
+            if verbose or i%10 == 0:
+                print('Epoch: {}, Phase: {}, epoch loss: {:.4f}, Dice Score (class 0): {:.4f}, Dice Score (class 1): {:.4f},Dice Score (class 2): {:.4f}'.format(i,phase,epoch_loss, epoch_dice_score_0, epoch_dice_score_1, epoch_dice_score_2))
+                print('-'*10)
+
+        if phase == 'validate' and epoch_loss < best_loss:
+            best_loss = epoch_loss
+            best_model_wts = model.state_dict() 
+
+    print('-'*50)    
+    time_elapsed = time.time() - since
+    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
+    print('Best val dice loss: {:4f}'.format(best_loss))
+
+    model.load_state_dict(best_model_wts)
+
+    return model, loss_hist, dice_score_0_hist, dice_score_1_hist, dice_score_2_hist
\ No newline at end of file