|
a |
|
b/trainers/base_trainer.py |
|
|
1 |
import os |
|
|
2 |
import os.path as osp |
|
|
3 |
from datetime import datetime |
|
|
4 |
|
|
|
5 |
import numpy as np |
|
|
6 |
import torch |
|
|
7 |
from torch import nn, optim |
|
|
8 |
from torch.utils.tensorboard import SummaryWriter |
|
|
9 |
from tqdm import tqdm |
|
|
10 |
|
|
|
11 |
from utils.network_utils import load_checkpoint, save_checkpoint |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
class BaseTrainer:
    """Base class for classification trainers.

    Wires together the model, optimizer, loss, and dataloaders from a config
    dict, runs train/validation epochs, writes TensorBoard scalars, and saves
    a checkpoint after every training epoch.

    Subclasses must implement ``_init_net`` and ``_init_dataloaders``.

    Expected ``config`` keys (read by this class):
        exp_dir       -- root directory for experiment outputs (required)
        exp_name      -- experiment name; defaults to a timestamp
        device        -- torch device string/object for model inputs and loss
        optim         -- name of an ``torch.optim`` class, e.g. "Adam"
        optim_params  -- kwargs dict passed to the optimizer constructor
        model_path    -- optional checkpoint path to resume from
        epochs        -- total number of epochs (default: 100000)
    """

    def __init__(self, config):
        """Build all training state from ``config`` (see class docstring)."""
        self.config = config
        self.exp_name = self.config.get("exp_name", None)
        if self.exp_name is None:
            # Timestamped fallback so unnamed runs never collide on disk.
            self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

        self.log_dir = osp.join(self.config["exp_dir"], self.exp_name, "logs")
        self.pth_dir = osp.join(self.config["exp_dir"], self.exp_name, "checkpoints")
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.pth_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self.log_dir)

        self.model = self._init_net()
        self.optimizer = self._init_optimizer()
        self.criterion = nn.CrossEntropyLoss().to(self.config["device"])

        self.train_loader, self.val_loader = self._init_dataloaders()

        pretrained_path = self.config.get("model_path", False)
        if pretrained_path:
            # Resume: restore weights/optimizer state plus progress counters.
            # NOTE(review): assumes load_checkpoint returns
            # (epoch, total_iter) in that order — confirm in utils.network_utils.
            self.training_epoch, self.total_iter = load_checkpoint(
                pretrained_path, self.model, optimizer=self.optimizer,
            )
        else:
            self.training_epoch = 0
            self.total_iter = 0

        self.epochs = self.config.get("epochs", int(1e5))

    def _init_net(self):
        """Create and return the model. Must be overridden by subclasses."""
        # BUGFIX: was `raise NotImplemented` — NotImplemented is a constant,
        # not an exception; raising it yields a confusing TypeError.
        raise NotImplementedError

    def _init_dataloaders(self):
        """Return (train_loader, val_loader). Must be overridden by subclasses."""
        # BUGFIX: was `raise NotImplemented` (see _init_net).
        raise NotImplementedError

    def _init_optimizer(self):
        """Instantiate the optimizer named by ``config["optim"]``.

        The class is looked up on ``torch.optim`` and constructed with the
        model's parameters plus ``config["optim_params"]`` as kwargs.
        """
        optimizer = getattr(optim, self.config["optim"])(
            self.model.parameters(), **self.config["optim_params"]
        )
        return optimizer

    def train_epoch(self):
        """Run one training epoch; log per-iteration and per-epoch scalars."""
        self.model.train()
        total_loss = 0

        # Accumulated ground-truth / predicted class ids over the epoch.
        gt_class = np.empty(0)
        pd_class = np.empty(0)

        for i, batch in enumerate(self.train_loader):
            # NOTE(review): batches are assumed to be dicts with "image" and
            # "class" tensors — confirm against the dataset implementation.
            inputs = batch["image"].to(self.config["device"])
            targets = batch["class"].to(self.config["device"])

            predictions = self.model(inputs)
            loss = self.criterion(predictions, targets)

            # Top-1 predicted class id per sample, as a flat numpy array.
            classes = predictions.topk(k=1)[1].view(-1).cpu().numpy()

            gt_class = np.concatenate((gt_class, batch["class"].numpy()))
            pd_class = np.concatenate((pd_class, classes))

            total_loss += loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if (i + 1) % 10 == 0:
                print(
                    "\tIter [%d/%d] Loss: %.4f"
                    % (i + 1, len(self.train_loader), loss.item()),
                )

            self.writer.add_scalar(
                "Train loss (iterations)", loss.item(), self.total_iter,
            )
            self.total_iter += 1

        total_loss /= len(self.train_loader)
        class_accuracy = float(np.mean(pd_class == gt_class))

        print("Train loss - {:4f}".format(total_loss))
        print("Train CLASS accuracy - {:4f}".format(class_accuracy))

        self.writer.add_scalar("Train loss (epochs)", total_loss, self.training_epoch)
        self.writer.add_scalar(
            "Train CLASS accuracy", class_accuracy, self.training_epoch,
        )

    def val(self):
        """Evaluate on the validation set; log loss and top-1 accuracy."""
        self.model.eval()
        total_loss = 0

        gt_class = np.empty(0)
        pd_class = np.empty(0)

        with torch.no_grad():
            # tqdm wraps the loader directly; the loop index was unused.
            for batch in tqdm(self.val_loader):
                inputs = batch["image"].to(self.config["device"])
                targets = batch["class"].to(self.config["device"])

                predictions = self.model(inputs)
                loss = self.criterion(predictions, targets)

                classes = predictions.topk(k=1)[1].view(-1).cpu().numpy()

                gt_class = np.concatenate((gt_class, batch["class"].numpy()))
                pd_class = np.concatenate((pd_class, classes))

                total_loss += loss.item()

        total_loss /= len(self.val_loader)
        class_accuracy = float(np.mean(pd_class == gt_class))

        print("Validation loss - {:4f}".format(total_loss))
        print("Validation CLASS accuracy - {:4f}".format(class_accuracy))

        self.writer.add_scalar("Validation loss", total_loss, self.training_epoch)
        self.writer.add_scalar(
            "Validation CLASS accuracy", class_accuracy, self.training_epoch,
        )

    def loop(self):
        """Main loop: train, checkpoint, then validate, for each epoch.

        Resumes from ``self.training_epoch`` so a restored run continues
        where it left off.
        """
        for epoch in range(self.training_epoch, self.epochs):
            print("Epoch - {}".format(self.training_epoch + 1))
            self.train_epoch()
            # Checkpoint before validation so a crash during val still
            # leaves the epoch's weights on disk.
            save_checkpoint(
                {
                    "state_dict": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "epoch": epoch,
                    "total_iter": self.total_iter,
                },
                osp.join(self.pth_dir, "{:0>8}.pth".format(epoch)),
            )
            self.val()

            self.training_epoch += 1