a b/trainers/base_trainer.py
1
import os
2
import os.path as osp
3
from datetime import datetime
4
5
import numpy as np
6
import torch
7
from torch import nn, optim
8
from torch.utils.tensorboard import SummaryWriter
9
from tqdm import tqdm
10
11
from utils.network_utils import load_checkpoint, save_checkpoint
12
13
14
class BaseTrainer:
    """Generic supervised-classification training harness.

    Builds the model, optimizer, criterion and dataloaders from a config
    dict, logs losses/accuracies to TensorBoard, saves a checkpoint every
    epoch and runs validation.

    Subclasses must implement ``_init_net`` and ``_init_dataloaders``.

    Expected config keys: ``exp_dir``, ``device``, ``optim``,
    ``optim_params``; optional: ``exp_name``, ``model_path``, ``epochs``.
    """

    def __init__(self, config):
        """Set up experiment directories, logging, model, optimizer,
        data and (optionally) resume from a checkpoint.

        Args:
            config (dict): experiment configuration (see class docstring).
        """
        self.config = config

        # Fall back to a timestamped experiment name when none is given.
        self.exp_name = self.config.get("exp_name", None)
        if self.exp_name is None:
            self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

        self.log_dir = osp.join(self.config["exp_dir"], self.exp_name, "logs")
        self.pth_dir = osp.join(self.config["exp_dir"], self.exp_name, "checkpoints")
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.pth_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self.log_dir)

        self.model = self._init_net()
        self.optimizer = self._init_optimizer()
        self.criterion = nn.CrossEntropyLoss().to(self.config["device"])

        self.train_loader, self.val_loader = self._init_dataloaders()

        # Resume training state if a checkpoint path is configured;
        # load_checkpoint restores model/optimizer state in place and
        # returns the saved (epoch, total_iter) counters.
        pretrained_path = self.config.get("model_path", False)
        if pretrained_path:
            self.training_epoch, self.total_iter = load_checkpoint(
                pretrained_path, self.model, optimizer=self.optimizer,
            )
        else:
            self.training_epoch = 0
            self.total_iter = 0

        self.epochs = self.config.get("epochs", int(1e5))

    def _init_net(self):
        """Build and return the model. Must be overridden by subclasses."""
        # Fixed: `raise NotImplemented` raises TypeError (NotImplemented is
        # not an exception); NotImplementedError is the correct idiom.
        raise NotImplementedError

    def _init_dataloaders(self):
        """Return ``(train_loader, val_loader)``. Must be overridden."""
        raise NotImplementedError

    def _init_optimizer(self):
        """Instantiate the optimizer named by ``config["optim"]`` (e.g.
        ``"Adam"``) over the model parameters, with ``config["optim_params"]``
        passed through as keyword arguments.
        """
        optimizer = getattr(optim, self.config["optim"])(
            self.model.parameters(), **self.config["optim_params"]
        )
        return optimizer

    def train_epoch(self):
        """Run one training epoch: optimize on every batch, log per-iteration
        and per-epoch loss plus epoch classification accuracy to TensorBoard.
        """
        self.model.train()
        total_loss = 0

        # Accumulated ground-truth / predicted class ids over the epoch,
        # used for the epoch-level accuracy.
        gt_class = np.empty(0)
        pd_class = np.empty(0)

        for i, batch in enumerate(self.train_loader):
            inputs = batch["image"].to(self.config["device"])
            targets = batch["class"].to(self.config["device"])

            predictions = self.model(inputs)
            loss = self.criterion(predictions, targets)

            # Top-1 predicted class per sample.
            classes = predictions.topk(k=1)[1].view(-1).cpu().numpy()

            gt_class = np.concatenate((gt_class, batch["class"].numpy()))
            pd_class = np.concatenate((pd_class, classes))

            total_loss += loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if (i + 1) % 10 == 0:
                print(
                    "\tIter [%d/%d] Loss: %.4f"
                    % (i + 1, len(self.train_loader), loss.item()),
                )

            self.writer.add_scalar(
                "Train loss (iterations)", loss.item(), self.total_iter,
            )
            self.total_iter += 1

        total_loss /= len(self.train_loader)
        # Guard against an empty epoch; np.mean over a bool array is the
        # fraction of correct predictions.
        class_accuracy = float(np.mean(pd_class == gt_class)) if pd_class.size else 0.0

        # Fixed format specs: "{:4f}" is width-4 (6 decimals); ".4f" was intended.
        print("Train loss - {:.4f}".format(total_loss))
        print("Train CLASS accuracy - {:.4f}".format(class_accuracy))

        self.writer.add_scalar("Train loss (epochs)", total_loss, self.training_epoch)
        self.writer.add_scalar(
            "Train CLASS accuracy", class_accuracy, self.training_epoch,
        )

    def val(self):
        """Evaluate on the validation set (no gradients) and log the mean
        loss and classification accuracy to TensorBoard.
        """
        self.model.eval()
        total_loss = 0

        gt_class = np.empty(0)
        pd_class = np.empty(0)

        with torch.no_grad():
            for i, batch in tqdm(enumerate(self.val_loader)):
                inputs = batch["image"].to(self.config["device"])
                targets = batch["class"].to(self.config["device"])

                predictions = self.model(inputs)
                loss = self.criterion(predictions, targets)

                classes = predictions.topk(k=1)[1].view(-1).cpu().numpy()

                gt_class = np.concatenate((gt_class, batch["class"].numpy()))
                pd_class = np.concatenate((pd_class, classes))

                total_loss += loss.item()

        total_loss /= len(self.val_loader)
        class_accuracy = float(np.mean(pd_class == gt_class)) if pd_class.size else 0.0

        print("Validation loss - {:.4f}".format(total_loss))
        print("Validation CLASS accuracy - {:.4f}".format(class_accuracy))

        self.writer.add_scalar("Validation loss", total_loss, self.training_epoch)
        self.writer.add_scalar(
            "Validation CLASS accuracy", class_accuracy, self.training_epoch,
        )

    def loop(self):
        """Main loop: train, checkpoint, then validate, for each epoch from
        ``training_epoch`` (possibly resumed) up to ``epochs``.
        """
        for epoch in range(self.training_epoch, self.epochs):
            print("Epoch - {}".format(self.training_epoch + 1))
            self.train_epoch()
            # Checkpoint before validation so a crash during val still
            # leaves this epoch's weights on disk.
            save_checkpoint(
                {
                    "state_dict": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "epoch": epoch,
                    "total_iter": self.total_iter,
                },
                osp.join(self.pth_dir, "{:0>8}.pth".format(epoch)),
            )
            self.val()

            self.training_epoch += 1