Diff of /trainers/base_trainer.py [000000] .. [fbbdf8]

Switch to side-by-side view

--- a
+++ b/trainers/base_trainer.py
@@ -0,0 +1,153 @@
+import os
+import os.path as osp
+from datetime import datetime
+
+import numpy as np
+import torch
+from torch import nn, optim
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+
+from utils.network_utils import load_checkpoint, save_checkpoint
+
+
class BaseTrainer:
    """Generic supervised-classification training harness.

    Reads everything from a ``config`` dict (``exp_dir``, ``device``,
    ``optim``, ``optim_params``, and optionally ``exp_name``, ``model_path``,
    ``epochs``), logs to TensorBoard, and checkpoints after every epoch.

    Subclasses must implement :meth:`_init_net` and :meth:`_init_dataloaders`.
    Dataloader batches are expected to be dicts with ``"image"`` and
    ``"class"`` tensors.
    """

    def __init__(self, config):
        self.config = config

        # Fall back to a timestamp so each unnamed run gets a unique directory.
        self.exp_name = self.config.get("exp_name", None)
        if self.exp_name is None:
            self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

        self.log_dir = osp.join(self.config["exp_dir"], self.exp_name, "logs")
        self.pth_dir = osp.join(self.config["exp_dir"], self.exp_name, "checkpoints")
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.pth_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self.log_dir)

        self.model = self._init_net()
        self.optimizer = self._init_optimizer()
        self.criterion = nn.CrossEntropyLoss().to(self.config["device"])

        self.train_loader, self.val_loader = self._init_dataloaders()

        # Resume from a checkpoint when "model_path" is set; otherwise start
        # from scratch. load_checkpoint restores model/optimizer state in place.
        pretrained_path = self.config.get("model_path", False)
        if pretrained_path:
            self.training_epoch, self.total_iter = load_checkpoint(
                pretrained_path, self.model, optimizer=self.optimizer,
            )
        else:
            self.training_epoch = 0
            self.total_iter = 0

        # Effectively "train until stopped" unless the config caps it.
        self.epochs = self.config.get("epochs", int(1e5))

    def _init_net(self):
        """Build and return the model (subclass responsibility)."""
        raise NotImplementedError

    def _init_dataloaders(self):
        """Return a ``(train_loader, val_loader)`` pair (subclass responsibility)."""
        raise NotImplementedError

    def _init_optimizer(self):
        """Instantiate the optimizer named by ``config["optim"]``.

        The name is looked up on ``torch.optim`` (e.g. ``"Adam"``, ``"SGD"``)
        and ``config["optim_params"]`` is forwarded as keyword arguments.
        """
        optimizer = getattr(optim, self.config["optim"])(
            self.model.parameters(), **self.config["optim_params"]
        )
        return optimizer

    def train_epoch(self):
        """Run one training pass over ``train_loader`` and log loss/accuracy."""
        self.model.train()
        total_loss = 0

        gt_class = np.empty(0)
        pd_class = np.empty(0)

        for i, batch in enumerate(self.train_loader):
            inputs = batch["image"].to(self.config["device"])
            targets = batch["class"].to(self.config["device"])

            predictions = self.model(inputs)
            loss = self.criterion(predictions, targets)

            # Predicted class per sample (argmax over the logits dimension).
            classes = predictions.argmax(dim=1).view(-1).cpu().numpy()

            gt_class = np.concatenate((gt_class, batch["class"].numpy()))
            pd_class = np.concatenate((pd_class, classes))

            total_loss += loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if (i + 1) % 10 == 0:
                print(
                    "\tIter [%d/%d] Loss: %.4f"
                    % (i + 1, len(self.train_loader), loss.item()),
                )

            self.writer.add_scalar(
                "Train loss (iterations)", loss.item(), self.total_iter,
            )
            self.total_iter += 1

        total_loss /= len(self.train_loader)
        class_accuracy = (pd_class == gt_class).mean()

        print("Train loss - {:4f}".format(total_loss))
        print("Train CLASS accuracy - {:4f}".format(class_accuracy))

        self.writer.add_scalar("Train loss (epochs)", total_loss, self.training_epoch)
        self.writer.add_scalar(
            "Train CLASS accuracy", class_accuracy, self.training_epoch,
        )

    def val(self):
        """Evaluate on ``val_loader`` (no gradients) and log loss/accuracy."""
        self.model.eval()
        total_loss = 0

        gt_class = np.empty(0)
        pd_class = np.empty(0)

        with torch.no_grad():
            for i, batch in tqdm(enumerate(self.val_loader)):
                inputs = batch["image"].to(self.config["device"])
                targets = batch["class"].to(self.config["device"])

                predictions = self.model(inputs)
                loss = self.criterion(predictions, targets)

                classes = predictions.argmax(dim=1).view(-1).cpu().numpy()

                gt_class = np.concatenate((gt_class, batch["class"].numpy()))
                pd_class = np.concatenate((pd_class, classes))

                total_loss += loss.item()

        total_loss /= len(self.val_loader)
        class_accuracy = (pd_class == gt_class).mean()

        print("Validation loss - {:4f}".format(total_loss))
        print("Validation CLASS accuracy - {:4f}".format(class_accuracy))

        self.writer.add_scalar("Validation loss", total_loss, self.training_epoch)
        self.writer.add_scalar(
            "Validation CLASS accuracy", class_accuracy, self.training_epoch,
        )

    def loop(self):
        """Train/checkpoint/validate from ``training_epoch`` up to ``epochs``."""
        for epoch in range(self.training_epoch, self.epochs):
            print("Epoch - {}".format(self.training_epoch + 1))
            self.train_epoch()
            # NOTE(review): the checkpoint stores "epoch": epoch even though the
            # epoch has already completed — if load_checkpoint resumes at this
            # value, the epoch is re-run on restart. Confirm load_checkpoint's
            # convention before changing.
            save_checkpoint(
                {
                    "state_dict": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "epoch": epoch,
                    "total_iter": self.total_iter,
                },
                osp.join(self.pth_dir, "{:0>8}.pth".format(epoch)),
            )
            self.val()

            self.training_epoch += 1