Diff of /solver.py [000000] .. [6f9c00]

--- /dev/null
+++ b/solver.py
@@ -0,0 +1,215 @@
+import glob
+import os
+
+import numpy as np
+import torch
+from nn_common_modules import losses as additional_losses
+from torch.optim import lr_scheduler
+
+import utils.common_utils as common_utils
+from utils.log_utils import LogWriter
+
+CHECKPOINT_DIR = 'checkpoints'
+CHECKPOINT_EXTENSION = 'pth.tar'
+
+
+class Solver(object):
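+    """
+    Wraps training, validation, checkpointing and logging for a segmentation
+    model.
+
+    Rough usage sketch (argument values and variable names below are
+    illustrative, not prescriptive):
+
+        solver = Solver(model, exp_name='quicknat_run1', device=0, num_class=28)
+        solver.train(train_loader, val_loader)
+        solver.save_best_model('saved_models/quicknat_run1.model')
+    """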
+
+    def __init__(self,
+                 model,
+                 exp_name,
+                 device,
+                 num_class,
+                 optim=torch.optim.Adam,
+                 optim_args={},
+                 loss_func=additional_losses.CombinedLoss(),
+                 model_name='quicknat',
+                 labels=None,
+                 num_epochs=10,
+                 log_nth=5,
+                 lr_scheduler_step_size=5,
+                 lr_scheduler_gamma=0.5,
+                 use_last_checkpoint=True,
+                 exp_dir='experiments',
+                 log_dir='logs'):
+
+        self.device = device
+        self.model = model
+
+        self.model_name = model_name
+        self.labels = labels
+        self.num_epochs = num_epochs
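+        # Move the loss criterion to the GPU when CUDA is available.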
+        if torch.cuda.is_available():
+            self.loss_func = loss_func.cuda(device)
+        else:
+            self.loss_func = loss_func
+        self.optim = optim(model.parameters(), **optim_args)
+        self.scheduler = lr_scheduler.StepLR(self.optim, step_size=lr_scheduler_step_size,
+                                             gamma=lr_scheduler_gamma)
+
+        exp_dir_path = os.path.join(exp_dir, exp_name)
+        common_utils.create_if_not(exp_dir_path)
+        common_utils.create_if_not(os.path.join(exp_dir_path, CHECKPOINT_DIR))
+        self.exp_dir_path = exp_dir_path
+
+        self.log_nth = log_nth
+        self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels)
+
+        self.use_last_checkpoint = use_last_checkpoint
+
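+        # Training state; overwritten below if an earlier checkpoint is restored.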
+        self.start_epoch = 1
+        self.start_iteration = 1
+
+        self.best_ds_mean = 0
+        self.best_ds_mean_epoch = 0
+
+        if use_last_checkpoint:
+            self.load_checkpoint()
+
+    # TODO: Need to correct the confusion matrix (CM) and Dice score calculation.
+    def train(self, train_loader, val_loader):
+        """
+        Train a given model with the provided data.
+
+        Inputs:
+        - train_loader: train data in torch.utils.data.DataLoader
+        - val_loader: val data in torch.utils.data.DataLoader
+        """
+        model, optim, scheduler = self.model, self.optim, self.scheduler
+        dataloaders = {
+            'train': train_loader,
+            'val': val_loader
+        }
+
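+        # Free cached GPU memory and move the model to the training device.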
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            model.cuda(self.device)
+
+        device_name = torch.cuda.get_device_name(self.device) if torch.cuda.is_available() else 'cpu'
+        print('START TRAINING: model name = %s, device = %s' % (self.model_name, device_name))
+        current_iteration = self.start_iteration
+        for epoch in range(self.start_epoch, self.num_epochs + 1):
+            print("\n==== Epoch [ %d  /  %d ] START ====" % (epoch, self.num_epochs))
+            for phase in ['train', 'val']:
+                print("<<<= Phase: %s =>>>" % phase)
+                loss_arr = []
+                out_list = []
+                y_list = []
+                if phase == 'train':
+                    model.train()
+                else:
+                    model.eval()
+                for i_batch, sample_batched in enumerate(dataloaders[phase]):
+                    X = sample_batched[0].type(torch.FloatTensor)
+                    y = sample_batched[1].type(torch.LongTensor)
+                    w = sample_batched[2].type(torch.FloatTensor)
+
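+                    # Move the batch to the model's device when it is on the GPU.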
+                    if model.is_cuda:
+                        X = X.cuda(self.device, non_blocking=True)
+                        y = y.cuda(self.device, non_blocking=True)
+                        w = w.cuda(self.device, non_blocking=True)
+
+                    output = model(X)
+                    loss = self.loss_func(output, y, w)
+                    if phase == 'train':
+                        optim.zero_grad()
+                        loss.backward()
+                        optim.step()
+                        if i_batch % self.log_nth == 0:
+                            self.logWriter.loss_per_iter(loss.item(), i_batch, current_iteration)
+                        current_iteration += 1
+
+                    loss_arr.append(loss.item())
+
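+                    # Convert logits to hard label predictions for the epoch-level metrics.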
+                    _, batch_output = torch.max(output, dim=1)
+                    out_list.append(batch_output.cpu())
+                    y_list.append(y.cpu())
+
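+                    # Drop per-batch tensors and release cached GPU memory before the next batch.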
+                    del X, y, w, output, batch_output, loss
+                    torch.cuda.empty_cache()
+                    if phase == 'val':
+                        if i_batch != len(dataloaders[phase]) - 1:
+                            print("#", end='', flush=True)
+                        else:
+                            print("100%", flush=True)
+
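+                # Epoch-level logging: loss, sample predictions, confusion matrix and Dice score.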
+                with torch.no_grad():
+                    out_arr, y_arr = torch.cat(out_list), torch.cat(y_list)
+                    self.logWriter.loss_per_epoch(loss_arr, phase, epoch)
+                    index = np.random.choice(len(dataloaders[phase].dataset.X), 3, replace=False)
+                    self.logWriter.image_per_epoch(model.predict(dataloaders[phase].dataset.X[index], self.device),
+                                                   dataloaders[phase].dataset.y[index], phase, epoch)
+                    self.logWriter.cm_per_epoch(phase, out_arr, y_arr, epoch)
+                    ds_mean = self.logWriter.dice_score_per_epoch(phase, out_arr, y_arr, epoch)
+                    if phase == 'val':
+                        if ds_mean > self.best_ds_mean:
+                            self.best_ds_mean = ds_mean
+                            self.best_ds_mean_epoch = epoch
+
+            print("==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) + "] DONE ====")
+            self.save_checkpoint({
+                'epoch': epoch + 1,
+                'start_iteration': current_iteration + 1,
+                'arch': self.model_name,
+                'state_dict': model.state_dict(),
+                'optimizer': optim.state_dict(),
+                'scheduler': scheduler.state_dict(),
+                'best_ds_mean': self.best_ds_mean,
+                'best_ds_mean_epoch': self.best_ds_mean_epoch
+            }, os.path.join(self.exp_dir_path, CHECKPOINT_DIR,
+                            'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION))
+
+        print('FINISH.')
+        self.logWriter.close()
+
+    def save_best_model(self, path):
+        """
+        Save model with its parameters to the given path. Conventionally the
+        path should end with "*.model".
+
+        Inputs:
+        - path: path string
+        """
+        print('Saving model... %s' % path)
+        print('Best Model at Epoch: ' + str(self.best_ds_mean_epoch))
+        self.load_checkpoint(self.best_ds_mean_epoch)
+
+        torch.save(self.model, path)
+
+    def save_checkpoint(self, state, filename):
+        torch.save(state, filename)
+
+    def load_checkpoint(self, epoch=None):
+        if epoch is not None:
+            checkpoint_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR,
+                                           'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION)
+            self._load_checkpoint_file(checkpoint_path)
+        else:
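+            # No epoch given: resume from the most recently created checkpoint, if any.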
+            all_files_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, '*.' + CHECKPOINT_EXTENSION)
+            list_of_files = glob.glob(all_files_path)
+            if len(list_of_files) > 0:
+                checkpoint_path = max(list_of_files, key=os.path.getctime)
+                self._load_checkpoint_file(checkpoint_path)
+            else:
+                self.logWriter.log(
+                    "=> no checkpoint found at '{}' folder".format(os.path.join(self.exp_dir_path, CHECKPOINT_DIR)))
+
+    def _load_checkpoint_file(self, file_path):
+        self.logWriter.log("=> loading checkpoint '{}'".format(file_path))
+        checkpoint = torch.load(file_path)
+        self.start_epoch = checkpoint['epoch']
+        self.start_iteration = checkpoint['start_iteration']
+        self.model.load_state_dict(checkpoint['state_dict'])
+        self.optim.load_state_dict(checkpoint['optimizer'])
+        if 'best_ds_mean' in checkpoint:
+            self.best_ds_mean = checkpoint['best_ds_mean']
+            self.best_ds_mean_epoch = checkpoint['best_ds_mean_epoch']
+
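+        # Move optimizer state tensors to the target device so training can resume there.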
+        for state in self.optim.state.values():
+            for k, v in state.items():
+                if torch.is_tensor(v):
+                    state[k] = v.to(self.device)
+
+        self.scheduler.load_state_dict(checkpoint['scheduler'])
+        self.logWriter.log("=> loaded checkpoint '{}' (epoch {})".format(file_path, checkpoint['epoch']))
\ No newline at end of file