Diff of /train.py [000000] .. [4b8af8]


--- a
+++ b/train.py
@@ -0,0 +1,202 @@
+import time
+
+import pandas as pd
+
+import torch
+import torch.nn as nn
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+import matplotlib.pyplot as plt
+from IPython.display import clear_output
+
+from dataset_dataloader import get_dataloader
+from loss_metric import Meter, BCEDiceLoss
+from segmentation_models_pytorch.unet import Unet
+
+
+class Trainer:
+    """
+    Factory for training proccess.
+    Args:
+        display_plot: if True - plot train history after each epoch.
+        net: neural network for mask prediction.
+        criterion: factory for calculating objective loss.
+        optimizer: optimizer for weights updating.
+        phases: list with train and validation phases.
+        dataloaders: dict with data loaders for train and val phases.
+        imgs_dir: path to folder with images.
+        masks_dir: path to folder with imasks.
+        path_to_csv: path to csv file.
+        meter: factory for storing and updating metrics.
+        batch_size: data batch size for one step weights updating.
+        num_epochs: num weights updation for all data.
+        accumulation_steps: the number of steps after which the optimization step can be taken
+                    (https://www.kaggle.com/c/understanding_cloud_organization/discussion/105614).
+        lr: learning rate for optimizer.
+        scheduler: scheduler for control learning rate.
+        losses: dict for storing lists with losses for each phase.
+        jaccard_scores: dict for storing lists with jaccard scores for each phase.
+        dice_scores: dict for storing lists with dice scores for each phase.
+    """
+    def __init__(self,
+                 net: nn.Module,
+                 criterion: nn.Module,
+                 lr: float,
+                 accumulation_steps: int,
+                 batch_size: int,
+                 num_epochs: int,
+                 imgs_dir: str,
+                 masks_dir: str,
+                 path_to_csv: str,
+                 display_plot: bool = True
+                ):
+
+        """Initialization."""
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        print("device:", self.device)
+        self.display_plot = display_plot
+        self.net = net
+        self.net = self.net.to(self.device)
+        self.criterion = criterion
+        self.optimizer = Adam(self.net.parameters(), lr=lr)
+        self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min",
+                                           patience=3, verbose=True)
+        # Convert the effective batch size into a number of micro-batches,
+        # never letting the window drop below one step.
+        self.accumulation_steps = max(1, accumulation_steps // batch_size)
+        self.phases = ["train", "val"]
+        self.num_epochs = num_epochs
+
+        self.dataloaders = {
+            phase: get_dataloader(
+                imgs_dir=imgs_dir,
+                masks_dir=masks_dir,
+                path_to_csv=path_to_csv,
+                phase=phase,
+                batch_size=batch_size,
+                num_workers=6
+            )
+            for phase in self.phases
+        }
+        self.best_loss = float("inf")
+        self.losses = {phase: [] for phase in self.phases}
+        self.dice_scores = {phase: [] for phase in self.phases}
+        self.jaccard_scores = {phase: [] for phase in self.phases}
+
+    def _compute_loss_and_outputs(self,
+                                  images: torch.Tensor,
+                                  targets: torch.Tensor):
+        images = images.to(self.device)
+        targets = targets.to(self.device)
+        logits = self.net(images)
+        loss = self.criterion(logits, targets)
+        return loss, logits
+        
+    def _do_epoch(self, epoch: int, phase: str):
+        print(f"{phase} epoch: {epoch} | time: {time.strftime('%H:%M:%S')}")
+
+        if phase == "train":
+            self.net.train()
+        else:
+            self.net.eval()
+        meter = Meter()
+        dataloader = self.dataloaders[phase]
+        total_batches = len(dataloader)
+        running_loss = 0.0
+        self.optimizer.zero_grad()
+        for itr, (images, targets) in enumerate(dataloader):
+            loss, logits = self._compute_loss_and_outputs(images, targets)
+            # Scale the loss so gradients accumulated over a window of
+            # `accumulation_steps` micro-batches average out correctly.
+            loss = loss / self.accumulation_steps
+            if phase == "train":
+                loss.backward()
+                # Take an optimizer step only once per accumulation window.
+                if (itr + 1) % self.accumulation_steps == 0:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+            running_loss += loss.item()
+            meter.update(logits.detach().cpu(),
+                         targets.detach().cpu())
+
+        epoch_loss = (running_loss * self.accumulation_steps) / total_batches
+        epoch_dice, epoch_iou = meter.get_metrics()
+        
+        self.losses[phase].append(epoch_loss)
+        self.dice_scores[phase].append(epoch_dice)
+        self.jaccard_scores[phase].append(epoch_iou)
+
+        return epoch_loss
+        
+    def train(self):
+        for epoch in range(self.num_epochs):
+            self._do_epoch(epoch, "train")
+            with torch.no_grad():
+                val_loss = self._do_epoch(epoch, "val")
+                self.scheduler.step(val_loss)
+            if self.display_plot:
+                self._plot_train_history()
+                
+            if val_loss < self.best_loss:
+                print(f"\n{'#'*20}\nSaved new checkpoint\n{'#'*20}\n")
+                self.best_loss = val_loss
+                torch.save(self.net.state_dict(), "best_model.pth")
+            print()
+        self._save_train_history()
+            
+    def _plot_train_history(self):
+        data = [self.losses, self.dice_scores, self.jaccard_scores]
+        colors = ["deepskyblue", "crimson"]
+        labels = [
+            f"train loss {self.losses['train'][-1]:.4f} | "
+            f"val loss {self.losses['val'][-1]:.4f}",
+
+            f"train dice score {self.dice_scores['train'][-1]:.4f} | "
+            f"val dice score {self.dice_scores['val'][-1]:.4f}",
+
+            f"train jaccard score {self.jaccard_scores['train'][-1]:.4f} | "
+            f"val jaccard score {self.jaccard_scores['val'][-1]:.4f}",
+        ]
+        
+        clear_output(True)
+        # "seaborn-dark-palette" was renamed in matplotlib >= 3.6;
+        # fall back to the new name if the old one is unavailable.
+        style = ("seaborn-dark-palette"
+                 if "seaborn-dark-palette" in plt.style.available
+                 else "seaborn-v0_8-dark-palette")
+        with plt.style.context(style):
+            fig, axes = plt.subplots(3, 1, figsize=(8, 10))
+            for i, ax in enumerate(axes):
+                ax.plot(data[i]['val'], c=colors[0], label="val")
+                ax.plot(data[i]['train'], c=colors[-1], label="train")
+                ax.set_title(labels[i])
+                ax.legend(loc="upper right")
+                
+            plt.tight_layout()
+            plt.show()
+            
+    def load_pretrained_model(self,
+                              state_path: str):
+        # map_location lets a GPU-trained checkpoint load on CPU-only hosts.
+        self.net.load_state_dict(torch.load(state_path,
+                                            map_location=self.device))
+        print("Pretrained model loaded")
+        
+    def _save_train_history(self):
+        """Write model weights and training logs to files."""
+        torch.save(self.net.state_dict(), "last_epoch_model.pth")
+
+        # Flatten the per-phase metric dicts into columns such as
+        # "train_loss", "val_dice", "val_jaccard", ...
+        logs_ = [self.losses, self.dice_scores, self.jaccard_scores]
+        suffixes = ["_loss", "_dice", "_jaccard"]
+        history = {
+            phase + suffix: values
+            for log, suffix in zip(logs_, suffixes)
+            for phase, values in log.items()
+        }
+        pd.DataFrame(history).to_csv("train_log.csv", index=False)