Diff of /semseg/train.py [000000] .. [cc8b8f]

--- /dev/null
+++ b/semseg/train.py
@@ -0,0 +1,99 @@
+import os
+import time
+
+import numpy as np
+import torch
+
+from semseg.loss import get_multi_dice_loss
+from config.config import LEARNING_RATE_REDUCTION_FACTOR
+from semseg.utils import multi_dice_coeff
+
+
+def train_model(net, optimizer, train_data, config, device=None, logs_folder=None):
+
+    print('Start training...')
+    net = net.to(device)
+    # train loop
+    for epoch in range(config.epochs):
+
+        epoch_start_time = time.time()
+        running_loss = 0.0
+
+        # lower the learning rate once, at the configured epoch
+        if epoch == config.low_lr_epoch:
+            config.lr = config.lr / LEARNING_RATE_REDUCTION_FACTOR
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = config.lr
+
+        # switch to train mode
+        net.train()
+
+        for i, data in enumerate(train_data):
+
+            inputs, labels = data['t1']['data'], data['label']['data']
+            if config.cuda:
+                inputs, labels = inputs.cuda(), labels.cuda()
+
+            # forward pass
+            outputs = net(inputs)
+
+            # get multi dice loss
+            loss = get_multi_dice_loss(outputs, labels, device=device)
+
+            # empty gradients, perform backward pass and update weights
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # accumulate the epoch loss; item() detaches it from the graph
+            running_loss += loss.item()
+
+        epoch_end_time = time.time()
+        epoch_elapsed_time = epoch_end_time - epoch_start_time
+
+        # print statistics
+        print('  [Epoch {:04d}] - Train dice loss: {:.4f} - Time: {:.1f} s'
+              .format(epoch + 1, running_loss / (i + 1), epoch_elapsed_time))
+
+        # switch to eval mode
+        net.eval()
+
+        # save a checkpoint every 'val_epochs' epochs
+        if epoch % config.val_epochs == 0:
+            if logs_folder is not None:
+                checkpoint_path = os.path.join(logs_folder, 'model_epoch_{:04d}.pth'.format(epoch))
+                torch.save(net.state_dict(), checkpoint_path)
+
+    print('Training ended!')
+    return net
+
+
+def val_model(net, val_data, config, device=None):
+
+    print("Start Validation...")
+    net = net.to(device)
+    # val loop
+    multi_dices = list()
+    with torch.no_grad():
+        net.eval()
+        for i, data in enumerate(val_data):
+            print("Iter {} on {}".format(i+1,len(val_data)))
+
+            inputs, labels = data['t1']['data'], data['label']['data']
+            if config.cuda:
+                inputs, labels = inputs.cuda(), labels.cuda()
+
+            # forward pass
+            outputs = net(inputs)
+            outputs = torch.argmax(outputs, dim=1)  #     B x Z x Y x X
+            outputs_np = outputs.cpu().numpy()      #     B x Z x Y x X
+            labels_np = labels.cpu().numpy()        # B x 1 x Z x Y x X
+            labels_np = labels_np[:, 0]             # drop channel dim -> B x Z x Y x X
+
+            multi_dice = multi_dice_coeff(labels_np, outputs_np, config.num_outs)
+            multi_dices.append(multi_dice)
+    multi_dices_np = np.array(multi_dices)
+    mean_multi_dice = np.mean(multi_dices_np)
+    std_multi_dice = np.std(multi_dices_np)
+    print("Multi-Dice: {:.4f} +/- {:.4f}".format(mean_multi_dice,std_multi_dice))
+    return multi_dices, mean_multi_dice, std_multi_dice
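
For reference, a minimal smoke-test driver for the two functions above might look as follows. This sketch is not part of the commit: TinyNet, make_batches, and every hyperparameter value are stand-ins invented for illustration, and it assumes the repo's semseg package (with get_multi_dice_loss and multi_dice_coeff) is importable and accepts the tensor shapes shown.

# Hypothetical usage sketch (not part of this commit). Assumes the semseg
# package from this repo is on PYTHONPATH; TinyNet and make_batches stand in
# for a real 3D network and a TorchIO-style DataLoader.
from types import SimpleNamespace

import torch
import torch.nn as nn

from semseg.train import train_model, val_model


class TinyNet(nn.Module):
    """Toy 3D segmentation net: one conv mapping 1 channel to num_classes."""
    def __init__(self, num_classes):
        super().__init__()
        self.conv = nn.Conv3d(1, num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


def make_batches(n, num_classes, shape=(2, 1, 16, 16, 16)):
    """Build a list of dicts shaped like the batches train_model expects:
    data['t1']['data'] is B x 1 x Z x Y x X, data['label']['data'] likewise."""
    return [{'t1': {'data': torch.randn(shape)},
             'label': {'data': torch.randint(0, num_classes,
                                             (shape[0], 1) + shape[2:])}}
            for _ in range(n)]


config = SimpleNamespace(
    epochs=2,          # attribute names mirror what train_model/val_model read
    lr=1e-3,
    low_lr_epoch=1,    # epoch at which lr is divided by LEARNING_RATE_REDUCTION_FACTOR
    val_epochs=1,      # checkpointing interval
    cuda=torch.cuda.is_available(),
    num_outs=3,        # number of segmentation classes
)

device = torch.device('cuda' if config.cuda else 'cpu')
net = TinyNet(config.num_outs)
optimizer = torch.optim.Adam(net.parameters(), lr=config.lr)

net = train_model(net, optimizer, make_batches(4, config.num_outs),
                  config, device=device, logs_folder=None)
multi_dices, mean_dice, std_dice = val_model(net, make_batches(2, config.num_outs),
                                             config, device=device)

Pointing logs_folder at a real directory would additionally write model_epoch_*.pth checkpoints every val_epochs epochs.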