--- a
+++ b/utils/log_utils.py
@@ -0,0 +1,155 @@
+import itertools
+import logging
+import os
+import re
+import shutil
+from textwrap import wrap
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from tensorboardX import SummaryWriter
+
+import utils.evaluator as eu
+
+plt.switch_backend('agg')
+plt.axis('scaled')
+
+
+# TODO: Add custom phase names
+class LogWriter(object):
+    def __init__(self, num_class, log_dir_name, exp_name, use_last_checkpoint=False, labels=None,
+                 cm_cmap=plt.cm.Blues):
+        self.num_class = num_class
+        train_log_path, val_log_path = os.path.join(log_dir_name, exp_name, "train"), os.path.join(log_dir_name,
+                                                                                                    exp_name,
+                                                                                                    "val")
+        if not use_last_checkpoint:
+            if os.path.exists(train_log_path):
+                shutil.rmtree(train_log_path)
+            if os.path.exists(val_log_path):
+                shutil.rmtree(val_log_path)
+
+        self.writer = {
+            'train': SummaryWriter(train_log_path),
+            'val': SummaryWriter(val_log_path)
+        }
+        self.curr_iter = 1
+        self.cm_cmap = cm_cmap
+        self.labels = self.beautify_labels(labels)
+        self.logger = logging.getLogger()
+        file_handler = logging.FileHandler("{0}/{1}.log".format(os.path.join(log_dir_name, exp_name), "console_logs"))
+        self.logger.addHandler(file_handler)
+
+    def log(self, text, phase='train'):
+        self.logger.info(text)
+
+    def loss_per_iter(self, loss_value, i_batch, current_iteration):
+        print('[Iteration : ' + str(i_batch) + '] Loss -> ' + str(loss_value))
+        self.writer['train'].add_scalar('loss/per_iteration', loss_value, current_iteration)
+
+    def loss_per_epoch(self, loss_arr, phase, epoch):
+        if phase == 'train':
+            loss = loss_arr[-1]
+        else:
+            loss = np.mean(loss_arr)
+        self.writer[phase].add_scalar('loss/per_epoch', loss, epoch)
+        print('epoch ' + phase + ' loss = ' + str(loss))
+
+    def cm_per_epoch(self, phase, output, correct_labels, epoch):
+        print("Confusion Matrix...", end='', flush=True)
+        _, cm = eu.dice_confusion_matrix(output, correct_labels, self.num_class, mode=phase)
+        self.plot_cm('confusion_matrix', phase, cm, epoch)
+        print("DONE", flush=True)
+
+    def plot_cm(self, caption, phase, cm, step=None):
+        fig = matplotlib.figure.Figure(figsize=(8, 8), dpi=180, facecolor='w', edgecolor='k')
+        ax = fig.add_subplot(1, 1, 1)
+
+        ax.imshow(cm, interpolation='nearest', cmap=self.cm_cmap)
+        ax.set_xlabel('Predicted', fontsize=7)
+        ax.set_xticks(np.arange(self.num_class))
+        c = ax.set_xticklabels(self.labels, fontsize=4, rotation=-90, ha='center')
+        ax.xaxis.set_label_position('bottom')
+        ax.xaxis.tick_bottom()
+
+        ax.set_ylabel('True Label', fontsize=7)
+        ax.set_yticks(np.arange(self.num_class))
+        ax.set_yticklabels(self.labels, fontsize=4, va='center')
+        ax.yaxis.set_label_position('left')
+        ax.yaxis.tick_left()
+
+        thresh = cm.max() / 2.
+        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
+            ax.text(j, i, format(cm[i, j], '.2f') if cm[i, j] != 0 else '.', horizontalalignment="center", fontsize=6,
+                    verticalalignment='center', color="white" if cm[i, j] > thresh else "black")
+
+        fig.set_tight_layout(True)
+        np.set_printoptions(precision=2)
+        if step is not None:
+            self.writer[phase].add_figure(caption + '/' + phase, fig, step)
+        else:
+            self.writer[phase].add_figure(caption + '/' + phase, fig)
+
+    def dice_score_per_epoch(self, phase, output, correct_labels, epoch):
+        print("Dice Score...", end='', flush=True)
+        ds = eu.dice_score_perclass(output, correct_labels, self.num_class, mode=phase)
+        self.plot_dice_score(phase, 'dice_score_per_epoch', ds, 'Dice Score', epoch)
+        ds_mean = torch.mean(ds)
+        print("DONE", flush=True)
+        return ds_mean.item()
+
+    def plot_dice_score(self, phase, caption, ds, title, step=None):
+        fig = matplotlib.figure.Figure(figsize=(8, 6), dpi=180, facecolor='w', edgecolor='k')
+        ax = fig.add_subplot(1, 1, 1)
+        ax.set_xlabel(title, fontsize=10)
+        ax.xaxis.set_label_position('top')
+        ax.bar(np.arange(self.num_class), ds)
+        ax.set_xticks(np.arange(self.num_class))
+        c = ax.set_xticklabels(self.labels, fontsize=6, rotation=-90, ha='center')
+        ax.xaxis.tick_bottom()
+        if step is not None:
+            self.writer[phase].add_figure(caption + '/' + phase, fig, step)
+        else:
+            self.writer[phase].add_figure(caption + '/' + phase, fig)
+
+    def plot_eval_box_plot(self, caption, class_dist, title):
+        fig = matplotlib.figure.Figure(figsize=(8, 6), dpi=180, facecolor='w', edgecolor='k')
+        ax = fig.add_subplot(1, 1, 1)
+        ax.set_xlabel(title, fontsize=10)
+        ax.xaxis.set_label_position('top')
+        ax.boxplot(class_dist)
+        ax.set_xticks(np.arange(self.num_class))
+        c = ax.set_xticklabels(self.labels, fontsize=6, rotation=-90, ha='center')
+        ax.xaxis.tick_bottom()
+        self.writer['val'].add_figure(caption, fig)
+
+    def image_per_epoch(self, prediction, ground_truth, phase, epoch):
+        print("Sample Images...", end='', flush=True)
+        ncols = 2
+        nrows = len(prediction)
+        fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 20))
+
+        for i in range(nrows):
+            ax[i][0].imshow(prediction[i], cmap='CMRmap', vmin=0, vmax=self.num_class - 1)
+            ax[i][0].set_title("Predicted", fontsize=10, color="blue")
+            ax[i][0].axis('off')
+            ax[i][1].imshow(ground_truth[i], cmap='CMRmap', vmin=0, vmax=self.num_class - 1)
+            ax[i][1].set_title("Ground Truth", fontsize=10, color="blue")
+            ax[i][1].axis('off')
+        fig.set_tight_layout(True)
+        self.writer[phase].add_figure('sample_prediction/' + phase, fig, epoch)
+        print('DONE', flush=True)
+
+    def graph(self, model, X):
+        self.writer['train'].add_graph(model, X)
+
+    def close(self):
+        self.writer['train'].close()
+        self.writer['val'].close()
+
+    def beautify_labels(self, labels):
+        classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in labels]
+        classes = ['\n'.join(wrap(l, 40)) for l in classes]
+        return classes
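
A minimal usage sketch of the new LogWriter, not part of the diff: the class count, label names, experiment name, and dummy arrays below are invented for illustration. Only the LogWriter API comes from utils/log_utils.py; the confusion-matrix and Dice methods are left out here because their expected input format is defined by utils.evaluator rather than by this file.

# Hypothetical wiring of LogWriter into a training loop. NUM_CLASS, LABELS,
# the log/experiment names and the random label maps are placeholders.
import numpy as np

from utils.log_utils import LogWriter

NUM_CLASS = 4
LABELS = ['Background', 'WhiteMatter', 'GreyMatter', 'Ventricles']  # CamelCase is split by beautify_labels

log_writer = LogWriter(num_class=NUM_CLASS, log_dir_name='logs',
                       exp_name='demo_run', labels=LABELS)

iteration = 1
for epoch in range(1, 3):
    epoch_losses = []
    for i_batch in range(5):
        loss_value = 1.0 / iteration  # stand-in for a real training loss
        log_writer.loss_per_iter(loss_value, i_batch, iteration)
        epoch_losses.append(loss_value)
        iteration += 1

    # 'train' logs the last loss of the epoch, 'val' logs the mean (see loss_per_epoch)
    log_writer.loss_per_epoch(epoch_losses, phase='train', epoch=epoch)
    log_writer.loss_per_epoch(epoch_losses, phase='val', epoch=epoch)

    # two dummy 2-D label maps; image_per_epoch indexes the axes as ax[i][0]/ax[i][1],
    # so at least two sample slices are expected per call
    prediction = np.random.randint(0, NUM_CLASS, size=(2, 64, 64))
    ground_truth = np.random.randint(0, NUM_CLASS, size=(2, 64, 64))
    log_writer.image_per_epoch(prediction, ground_truth, phase='val', epoch=epoch)

    log_writer.log('finished epoch {}'.format(epoch))

log_writer.close()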