--- /dev/null
+++ b/bme1312/solver.py
@@ -0,0 +1,320 @@
+"""
+BME1312
+DO NOT MODIFY anything in this file.
+"""
+import os
+import itertools
+from typing import Callable
+from torch.utils.tensorboard import SummaryWriter
+import numpy as np
+
+from matplotlib import pyplot as plt
+from tqdm.auto import tqdm  # notebook-friendly progress bars
+
+import torch
+import torchvision
+from torch import nn
+from torch.utils import data as Data
+
+from .utils import imsshow, image_mask_overlay
+from .evaluation import get_sensitivity, get_specificity, get_precision, get_F1, get_JS, get_DC
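+
+# Typical usage (sketch only; `unet` and the data loaders are placeholders for
+# objects the calling notebook defines):
+#     solver = Lab2Solver(model=unet,
+#                         optimizer=torch.optim.Adam(unet.parameters(), lr=1e-3),
+#                         criterion=nn.BCEWithLogitsLoss())
+#     solver.train(epochs=50, data_loader=train_loader, val_loader=val_loader)
+#     solver.validate(test_loader)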
+
+
+class Solver(object):
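+    """Generic training/validation loop; subclasses implement `_step`."""
+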
+    def __init__(self,
+                 model: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 criterion: Callable,
+                 lr_scheduler=None,
+                 recorder: dict = None,
+                 device=None):
+        device = device if device is not None else \
+            ('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.device = device
+        self.recorder = recorder
+
+        self.model = self.to_device(model)
+        self.optimizer = optimizer
+        self.criterion = criterion
+        self.lr_scheduler = lr_scheduler
+
+    def _step(self, batch: tuple, is_compute_metrics: bool = True) -> dict:
+        """Return a dict with at least 'loss' and 'batch_size'; subclasses may add 'metric_*' entries."""
+        raise NotImplementedError()
+
+    def to_device(self, x):
+        if isinstance(x, (torch.Tensor, nn.Module)):
+            return x.to(self.device)
+        elif isinstance(x, np.ndarray):
+            return torch.tensor(x, device=self.device)
+        else:
+            raise RuntimeError(f"Cannot transfer data of type {type(x)} to device {self.device}.")
+
+    def to_numpy(self, x):
+        if isinstance(x, np.ndarray):
+            return x
+        elif isinstance(x, torch.Tensor):
+            return x.detach().cpu().numpy()
+        else:
+            raise RuntimeError(f"Cannot convert type {type(x)} into numpy array.")
+
+    def train(self,
+              epochs: int,
+              data_loader,
+              *,
+              val_loader=None,
+              save_path='./model.pth',
+              img_name='img',
+              is_plot=True) -> None:
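+        """
+        Train for `epochs` epochs. Per-epoch average losses go to TensorBoard,
+        the final weights are saved to `save_path`, and the loss curves are
+        plotted and saved to `<img_name>.png` when `is_plot` is True.
+        """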
+        torch.cuda.empty_cache()
+
+        writer = SummaryWriter()
+        val_loss_epochs = []
+        train_loss_epochs = []
+        pbar_train = tqdm(total=len(data_loader.sampler), unit='img')
+        if val_loader is not None:
+            pbar_val = tqdm(total=len(val_loader.sampler), desc='[Validation] waiting', unit='img')
+        for epoch in range(epochs):
+            pbar_train.reset()
+            pbar_train.set_description(desc=f'[Train] Epoch {epoch + 1}/{epochs}')
+            epoch_loss_acc = 0
+            epoch_size = 0
+            for batch in data_loader:
+                self.model.train()
+                # forward
+                step_dict = self._step(batch)
+                batch_size = step_dict['batch_size']
+                loss = step_dict['loss']
+
+                # backward
+                self.optimizer.zero_grad()
+                loss.backward()
+
+                # optimize
+                self.optimizer.step()
+
+                # update information (assumes `criterion` returns the batch-mean
+                # loss, PyTorch's default reduction)
+                loss_value = loss.item()
+                epoch_loss_acc += loss_value * batch_size
+                epoch_size += batch_size
+                pbar_train.update(batch_size)
+                pbar_train.set_postfix(loss=loss_value)
+
+            epoch_avg_loss = epoch_loss_acc / epoch_size
+            pbar_train.set_postfix(epoch_avg_loss=epoch_avg_loss)
+            train_loss_epochs.append(epoch_avg_loss)
+            writer.add_scalar('Loss/train', epoch_avg_loss, epoch)
+
+            if self.lr_scheduler:
+                self.lr_scheduler.step()
+
+            # validate if `val_loader` is specified
+            if val_loader is not None:
+                pbar_val.reset()
+                pbar_val.set_description(desc=f'[Validation] Epoch {epoch + 1}/{epochs}')
+                val_avg_loss = self.validate(val_loader, pbar=pbar_val, is_compute_metrics=False)
+                val_loss_epochs.append(val_avg_loss)
+                writer.add_scalar('Loss/val', val_avg_loss, epoch)
+
+        pbar_train.close()
+        torch.save(self.model.state_dict(), save_path)
+        if val_loader is not None:
+            pbar_val.close()
+        train_loss_epochs = np.asarray(train_loss_epochs)
+        val_loss_epochs = np.asarray(val_loss_epochs)
+        if is_plot:
+            plt.figure()
+            plt.plot(list(range(1, epochs + 1)), train_loss_epochs, label='train')
+            if val_loader is not None:
+                plt.plot(list(range(1, epochs + 1)), val_loss_epochs, label='validation')
+            plt.legend()
+            plt.xlabel('Epochs')
+            plt.ylabel('Loss')
+            # save the loss curve to disk before displaying it
+            plt.savefig(img_name + '.png')
+            plt.show()
+            plt.close('all')
+
+        writer.close()
+
+    def validate(self, data_loader, *, pbar=None, is_compute_metrics=True) -> float:
+        """
+        :param pbar: when an external pbar is supplied, skip printing the metric summary
+        :return: average validation loss per sample
+        """
+        torch.cuda.empty_cache()
+
+        metrics_acc = {}
+        loss_acc = 0
+        size_acc = 0
+        is_need_log = (pbar is None)
+        with torch.no_grad():
+            if pbar is None:
+                pbar = tqdm(total=len(data_loader.sampler), desc='[Validation]', unit='img')
+            for batch in data_loader:
+                self.model.eval()
+
+                # forward
+                step_dict = self._step(batch, is_compute_metrics=is_compute_metrics)
+                batch_size = step_dict['batch_size']
+                loss = step_dict['loss']
+                loss_value = loss.item()
+
+                # aggregate metrics
+                metrics_acc = self._aggregate_metrics(metrics_acc, step_dict)
+
+                # update information (assumes batch-mean loss; see `train`)
+                loss_acc += loss_value * batch_size
+                size_acc += batch_size
+                pbar.update(batch_size)
+                pbar.set_postfix(loss=loss_value)
+
+        val_avg_loss = loss_acc / size_acc
+        pbar.set_postfix(val_avg_loss=val_avg_loss)
+        if is_need_log:
+            pbar.close()  # destroy newly created pbar
+            print('=' * 30 + ' Measurements ' + '=' * 30)
+            for k, v in metrics_acc.items():
+                print(f"[{k}] {v / size_acc}")
+        return val_avg_loss
+
+    def _aggregate_metrics(self, metrics_acc: dict, step_dict: dict):
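+        """
+        Accumulate every 'metric_*' entry of `step_dict` into `metrics_acc`,
+        weighted by batch size (e.g. step_dict['metric_avg_F1Score'] adds to
+        metrics_acc['avg_F1Score']); `validate` later divides by sample count.
+        """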
+        batch_size = step_dict['batch_size']
+        for k, v in step_dict.items():
+            if k.startswith('metric_'):
+                value = v * batch_size
+                metric_name = k[7:]
+                if metric_name not in metrics_acc:
+                    metrics_acc[metric_name] = value
+                else:
+                    metrics_acc[metric_name] += value
+        return metrics_acc
+
+    def visualize(self, data_loader, idx):
+        raise NotImplementedError()
+
+    def get_recorder(self) -> dict:
+        return self.recorder
+
+
+class Lab2Solver(Solver):
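+    """Segmentation solver: sigmoid outputs thresholded at 0.5, with Dice/Jaccard/F1 metrics."""
+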
+    def _step(self, batch, is_compute_metrics=True) -> dict:
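+        """One supervised step: forward pass, loss, and (in eval mode) segmentation metrics."""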
+        image, seg_gt = batch
+
+        image = self.to_device(image)  # [B, C=1, H, W]
+        seg_gt = self.to_device(seg_gt)  # [B, C=1, H, W]
+        B, C, H, W = image.shape
+
+        pred_seg = self.model(image)  # [B, C=1, H, W]
+        loss = self.criterion(pred_seg, seg_gt)
+
+        step_dict = {
+            'loss': loss,
+            'batch_size': B
+        }
+
+        # ============ compute metrics
+        if not self.model.training and is_compute_metrics:
+            pred_seg_probs = torch.sigmoid(pred_seg)
+            SE = get_sensitivity(pred_seg_probs, seg_gt)
+            SP = get_specificity(pred_seg_probs, seg_gt)
+            PC = get_precision(pred_seg_probs, seg_gt)
+            F1 = get_F1(pred_seg_probs, seg_gt)
+            JS = get_JS(pred_seg_probs, seg_gt)
+            DC = get_DC(pred_seg_probs, seg_gt)
+
+            step_dict['metric_avg_Sensitivity'] = SE
+            step_dict['metric_avg_Specificity'] = SP
+            step_dict['metric_avg_Precision'] = PC
+            step_dict['metric_avg_F1Score'] = F1
+            step_dict['metric_avg_JaccardSimilarity'] = JS
+            step_dict['metric_avg_DiceCoefficient'] = DC
+
+        return step_dict
+
+    def visualize(self, data_loader, idx, *, dpi=100):
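+        """
+        Show sample `idx` from `data_loader`: image, ground-truth mask and
+        thresholded prediction, followed by mask/image overlays.
+        """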
+        with torch.no_grad():
+            # fetch data batch
+            if idx < 0 or idx >= len(data_loader.sampler):
+                raise RuntimeError("idx is out of range.")
+
+            batch_idx = idx // data_loader.batch_size
+            batch_offset = idx - batch_idx * data_loader.batch_size
+
+            batch = next(itertools.islice(data_loader, batch_idx, None))
+
+            # inference
+            image, seg_gt = batch
+
+            image = self.to_device(image)  # [B, C=1, H, W]
+            seg_gt = self.to_device(seg_gt)  # [B, C=1, H, W]
+            B, C, H, W = image.shape
+
+            self.model.eval()
+            pred_seg = self.model(image)  # [B, C=1, H, W]
+
+            pred_seg_probs = torch.sigmoid(pred_seg)
+            # zero out probabilities below the 0.5 threshold
+            pred_seg_mask = torch.where(pred_seg_probs > 0.5, pred_seg_probs, torch.zeros_like(pred_seg_probs))
+            # map the winning channel to a label value (channels assumed stored in
+            # descending label order, i.e. channel i -> label 4 - i)
+            pred_seg_mask_argmax = 3 - torch.argmax(pred_seg_mask, dim=1, keepdim=True) + 1
+            # pixels where no channel passes the threshold stay background (0)
+            foreground = torch.sum(pred_seg_mask > 0, dim=1, keepdim=True) > 0
+            pred_seg_mask = torch.where(foreground, pred_seg_mask_argmax, torch.zeros_like(pred_seg_mask_argmax))
+            DC = get_DC(pred_seg_mask, seg_gt)
+            
+            image = self.to_numpy(image[batch_offset, 0, :, :])
+            seg_gt = self.to_numpy(seg_gt[batch_offset, 0, :, :])
+            pred_seg_mask = self.to_numpy(pred_seg_mask[batch_offset, 0, :, :])
+
+        seg_gt = (seg_gt > 0.5) * 1.0
+        
+        seg_gt_overlay = image_mask_overlay(image, seg_gt)
+        pred_overlay = image_mask_overlay(image, pred_seg_mask)
+
+        imsshow([image, seg_gt, pred_seg_mask],
+                titles=['Image',
+                        'Segmentation GT',
+                        f"Prediction (DICE {DC:.2f})"],
+                num_col=3,
+                dpi=dpi,
+                is_colorbar=True)
+        imsshow([seg_gt_overlay, pred_overlay],
+                titles=['Segmentation GT',
+                        f"Prediction (DICE {DC:.2f})"],
+                num_col=2,
+                dpi=dpi,
+                is_colorbar=False)
+
+    def inference_all(self, data_loader, output_path) -> None:
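+        """
+        Run inference over `data_loader` and save each thresholded mask as
+        case_<filename>_segmentation.jpg under `output_path` (the directory
+        is assumed to exist).
+        """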
+        torch.cuda.empty_cache()
+
+        with torch.no_grad():
+            self.model.eval()
+            for batch in tqdm(data_loader):
+                image, filename = batch
+                B, C, H, W = image.shape
+
+                image = self.to_device(image)  # [B, C=1, H, W]
+
+                pred_seg = self.model(image)  # [B, C=1, H, W]
+
+                pred_seg_probs = torch.sigmoid(pred_seg)
+                pred_seg_mask = pred_seg_probs > 0.5  # default thresholding: 0.5
+                pred_seg_mask = pred_seg_mask.float()
+                pred_seg_mask = pred_seg_mask.cpu()  # [B, C=1, H, W]
+
+                for i in range(B):
+                    torchvision.utils.save_image(
+                        pred_seg_mask[i, 0, :, :],
+                        os.path.join(output_path, f'case_{filename[i]}_segmentation.jpg')
+                        )