--- a/solver.py
+++ b/solver.py
@@ -0,0 +1,215 @@
+import glob
+import os
+
+import numpy as np
+import torch
+from nn_common_modules import losses as additional_losses
+from torch.optim import lr_scheduler
+
+import utils.common_utils as common_utils
+from utils.log_utils import LogWriter
+
+CHECKPOINT_DIR = 'checkpoints'
+CHECKPOINT_EXTENSION = 'pth.tar'
+
+
+class Solver(object):
+
+    def __init__(self,
+                 model,
+                 exp_name,
+                 device,
+                 num_class,
+                 optim=torch.optim.Adam,
+                 optim_args={},
+                 loss_func=additional_losses.CombinedLoss(),
+                 model_name='quicknat',
+                 labels=None,
+                 num_epochs=10,
+                 log_nth=5,
+                 lr_scheduler_step_size=5,
+                 lr_scheduler_gamma=0.5,
+                 use_last_checkpoint=True,
+                 exp_dir='experiments',
+                 log_dir='logs'):
+
+        self.device = device
+        self.model = model
+
+        self.model_name = model_name
+        self.labels = labels
+        self.num_epochs = num_epochs
+        if torch.cuda.is_available():
+            self.loss_func = loss_func.cuda(device)
+        else:
+            self.loss_func = loss_func
+        self.optim = optim(model.parameters(), **optim_args)
+        self.scheduler = lr_scheduler.StepLR(self.optim, step_size=lr_scheduler_step_size,
+                                             gamma=lr_scheduler_gamma)
+
+        exp_dir_path = os.path.join(exp_dir, exp_name)
+        common_utils.create_if_not(exp_dir_path)
+        common_utils.create_if_not(os.path.join(exp_dir_path, CHECKPOINT_DIR))
+        self.exp_dir_path = exp_dir_path
+
+        self.log_nth = log_nth
+        self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels)
+
+        self.use_last_checkpoint = use_last_checkpoint
+
+        self.start_epoch = 1
+        self.start_iteration = 1
+
+        # Best mean Dice score observed on the validation set so far.
+        self.best_ds_mean = 0
+        self.best_ds_mean_epoch = 0
+
+        # Resume from the most recent checkpoint of this experiment, if one exists.
+        if use_last_checkpoint:
+            self.load_checkpoint()
+
+    # TODO: Need to correct the CM and Dice score calculation.
+    def train(self, train_loader, val_loader):
+        """
+        Train a given model with the provided data.
+
+        Inputs:
+        - train_loader: train data in torch.utils.data.DataLoader
+        - val_loader: val data in torch.utils.data.DataLoader
+        """
+        model, optim, scheduler = self.model, self.optim, self.scheduler
+        dataloaders = {
+            'train': train_loader,
+            'val': val_loader
+        }
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            model.cuda(self.device)
+
+        # get_device_name() fails on CPU-only runs, so fall back to the plain device string.
+        device_name = torch.cuda.get_device_name(self.device) if torch.cuda.is_available() else str(self.device)
+        print('START TRAINING: model name = %s, device = %s' % (self.model_name, device_name))
+        current_iteration = self.start_iteration
+        for epoch in range(self.start_epoch, self.num_epochs + 1):
+            print("\n==== Epoch [ %d / %d ] START ====" % (epoch, self.num_epochs))
+            for phase in ['train', 'val']:
+                print("<<<= Phase: %s =>>>" % phase)
+                loss_arr = []
+                out_list = []
+                y_list = []
+                if phase == 'train':
+                    model.train()
+                else:
+                    model.eval()
+                for i_batch, sample_batched in enumerate(dataloaders[phase]):
+                    X = sample_batched[0].type(torch.FloatTensor)
+                    y = sample_batched[1].type(torch.LongTensor)
+                    w = sample_batched[2].type(torch.FloatTensor)
+
+                    if model.is_cuda:
+                        X = X.cuda(self.device, non_blocking=True)
+                        y = y.cuda(self.device, non_blocking=True)
+                        w = w.cuda(self.device, non_blocking=True)
+
+                    output = model(X)
+                    loss = self.loss_func(output, y, w)
+                    if phase == 'train':
+                        optim.zero_grad()
+                        loss.backward()
+                        optim.step()
+                        if i_batch % self.log_nth == 0:
+                            self.logWriter.loss_per_iter(loss.item(), i_batch, current_iteration)
+                        current_iteration += 1
+
+                    loss_arr.append(loss.item())
+
+                    _, batch_output = torch.max(output, dim=1)
+                    out_list.append(batch_output.cpu())
+                    y_list.append(y.cpu())
+
+                    del X, y, w, output, batch_output, loss
+                    torch.cuda.empty_cache()
+                    if phase == 'val':
+                        if i_batch != len(dataloaders[phase]) - 1:
+                            print("#", end='', flush=True)
+                        else:
+                            print("100%", flush=True)
+
+                with torch.no_grad():
+                    out_arr, y_arr = torch.cat(out_list), torch.cat(y_list)
+                    self.logWriter.loss_per_epoch(loss_arr, phase, epoch)
+                    index = np.random.choice(len(dataloaders[phase].dataset.X), 3, replace=False)
+                    self.logWriter.image_per_epoch(model.predict(dataloaders[phase].dataset.X[index], self.device),
+                                                   dataloaders[phase].dataset.y[index], phase, epoch)
+                    self.logWriter.cm_per_epoch(phase, out_arr, y_arr, epoch)
+                    ds_mean = self.logWriter.dice_score_per_epoch(phase, out_arr, y_arr, epoch)
+                    if phase == 'val' and ds_mean > self.best_ds_mean:
+                        self.best_ds_mean = ds_mean
+                        self.best_ds_mean_epoch = epoch
+
+            # Step the LR schedule once per epoch, after the optimiser updates
+            # (the ordering expected since PyTorch 1.1).
+            scheduler.step()
+
+            print("==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) + "] DONE ====")
+            self.save_checkpoint({
+                'epoch': epoch + 1,
+                'start_iteration': current_iteration + 1,
+                'arch': self.model_name,
+                'state_dict': model.state_dict(),
+                'optimizer': optim.state_dict(),
+                'scheduler': scheduler.state_dict(),
+                'best_ds_mean': self.best_ds_mean,
+                'best_ds_mean_epoch': self.best_ds_mean_epoch
+            }, os.path.join(self.exp_dir_path, CHECKPOINT_DIR,
+                            'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION))
+
+        print('FINISH.')
+        self.logWriter.close()
+
+    def save_best_model(self, path):
+        """
+        Save the model with its parameters to the given path. Conventionally the
+        path should end with "*.model".
+
+        Inputs:
+        - path: path string
+        """
+        print('Saving model... %s' % path)
+        print('Best model at epoch: ' + str(self.best_ds_mean_epoch))
+        # Reload the checkpoint of the best validation epoch before saving.
+        self.load_checkpoint(self.best_ds_mean_epoch)
+
+        torch.save(self.model, path)
+
+    def save_checkpoint(self, state, filename):
+        torch.save(state, filename)
+
+    def load_checkpoint(self, epoch=None):
+        if epoch is not None:
+            checkpoint_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR,
+                                           'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION)
+            self._load_checkpoint_file(checkpoint_path)
+        else:
+            all_files_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, '*.' + CHECKPOINT_EXTENSION)
+            list_of_files = glob.glob(all_files_path)
+            if len(list_of_files) > 0:
+                # Resume from the most recently written checkpoint.
+                checkpoint_path = max(list_of_files, key=os.path.getctime)
+                self._load_checkpoint_file(checkpoint_path)
+            else:
+                self.logWriter.log(
+                    "=> no checkpoint found at '{}' folder".format(os.path.join(self.exp_dir_path, CHECKPOINT_DIR)))
+
+    def _load_checkpoint_file(self, file_path):
+        self.logWriter.log("=> loading checkpoint '{}'".format(file_path))
+        checkpoint = torch.load(file_path)
+        self.start_epoch = checkpoint['epoch']
+        self.start_iteration = checkpoint['start_iteration']
+        self.model.load_state_dict(checkpoint['state_dict'])
+        self.optim.load_state_dict(checkpoint['optimizer'])
+        if 'best_ds_mean' in checkpoint:
+            self.best_ds_mean = checkpoint['best_ds_mean']
+            self.best_ds_mean_epoch = checkpoint['best_ds_mean_epoch']
+
+        # Move any optimiser state tensors restored from the checkpoint onto the current device.
+        for state in self.optim.state.values():
+            for k, v in state.items():
+                if torch.is_tensor(v):
+                    state[k] = v.to(self.device)
+
+        self.scheduler.load_state_dict(checkpoint['scheduler'])
+        self.logWriter.log("=> loaded checkpoint '{}' (epoch {})".format(file_path, checkpoint['epoch']))
\ No newline at end of file
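
For reviewers, a minimal, hypothetical sketch of how the new `Solver` might be driven end to end. The `QuickNat` import, its constructor arguments, and the toy dataset are illustrative assumptions and not part of this diff; the only requirements taken from the code above are that batches yield `(image, label, class-weight)` triples and that the dataset exposes `.X`/`.y` arrays for the per-epoch image logging.

```python
# Hypothetical usage sketch of the Solver added in this diff.
# QuickNat and its constructor arguments are assumptions, not part of the change.
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from solver import Solver
from quicknat import QuickNat  # assumed project model exposing .is_cuda and .predict()


class SliceDataset(Dataset):
    """Toy dataset yielding (image, label, class-weight) triples as train() expects.

    Solver.train() also reads .X and .y directly for per-epoch image logging,
    so both are kept as plain numpy arrays here.
    """

    def __init__(self, n=32, num_class=28, size=256):
        self.X = np.random.rand(n, 1, size, size).astype(np.float32)
        self.y = np.random.randint(0, num_class, (n, size, size))
        self.w = np.ones((n, size, size), dtype=np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.w[idx]


device = 0 if torch.cuda.is_available() else torch.device('cpu')
model = QuickNat({'num_channels': 1, 'num_class': 28})  # illustrative constructor args

train_loader = DataLoader(SliceDataset(), batch_size=4, shuffle=True)
val_loader = DataLoader(SliceDataset(), batch_size=4, shuffle=False)

solver = Solver(model,
                exp_name='quicknat_demo',
                device=device,
                num_class=28,
                num_epochs=2,
                use_last_checkpoint=False)  # start fresh instead of resuming

solver.train(train_loader, val_loader)
solver.save_best_model('experiments/quicknat_demo/quicknat_demo.model')
```

With `use_last_checkpoint=True` (the default), the same script would instead resume from the newest checkpoint found under `experiments/<exp_name>/checkpoints/`.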