[fbbdf8]: / runners / base_runner.py

Download this file

62 lines (43 with data), 1.9 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import os.path as osp
from datetime import datetime
import numpy as np
import torch
from tqdm import tqdm
from utils.network_utils import load_checkpoint
class BaseRunner:
    """Base class for runners that evaluate a classification model.

    On construction it resolves the results directory, builds the network and
    the inference dataloader through subclass hooks, and loads weights from
    ``config["model_path"]`` via ``load_checkpoint``.

    Config keys read by this class:
        exp_dir (required): root directory for experiment outputs.
        exp_name (optional): sub-directory name; defaults to a timestamp.
        model_path (required): checkpoint path handed to ``load_checkpoint``.
        device: target passed to ``Tensor.to`` for input images in
            ``inference``.

    Subclasses must implement ``_init_net`` and ``_init_dataloader``.
    """

    def __init__(self, config):
        """Create the results directory, build model/dataloader, load weights.

        Args:
            config: Mapping of run settings (see the class docstring).

        Raises:
            KeyError: If ``config`` has no ``"exp_dir"`` entry.
            Exception: If ``"model_path"`` is missing or falsy in ``config``.
        """
        self.config = config
        self.exp_name = self.config.get("exp_name", None)
        if self.exp_name is None:
            # Timestamp fallback so each unnamed run gets its own folder.
            self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.res_dir = osp.join(self.config["exp_dir"], self.exp_name, "results")
        os.makedirs(self.res_dir, exist_ok=True)

        self.model = self._init_net()
        self.inference_loader = self._init_dataloader()

        pretrained_path = self.config.get("model_path", False)
        if not pretrained_path:
            raise Exception(
                "model_path doesn't exist in config. Please specify checkpoint path",
            )
        load_checkpoint(pretrained_path, self.model)

    def _init_net(self):
        """Build and return the model; must be overridden by subclasses."""
        # NotImplementedError (not the NotImplemented constant, which is not
        # an exception and would itself trigger a TypeError when raised).
        raise NotImplementedError(f"{type(self).__name__} must implement _init_net()")

    def _init_dataloader(self):
        """Build and return the inference dataloader; must be overridden."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement _init_dataloader()"
        )

    def inference(self):
        """Run the model over ``inference_loader``, report and save results.

        Prints top-1 class accuracy and writes one integer predicted class per
        line to ``<res_dir>/predictions.txt``.

        Returns:
            float: Top-1 accuracy in [0, 1], or ``nan`` if the loader yielded
            no samples.
        """
        self.model.eval()
        # Iterate the loader directly (no enumerate) so tqdm can use len()
        # for a proper progress total.
        gt_class, pd_class = self._predict(tqdm(self.inference_loader))
        class_accuracy = self._accuracy(pd_class, gt_class)
        print("Validation CLASS accuracy - {:.4f}".format(class_accuracy))
        self._save_predictions(pd_class)
        return class_accuracy

    def _predict(self, batches):
        """Return ``(ground_truth, predicted)`` class arrays over ``batches``.

        Each batch is expected to be a mapping with an ``"image"`` tensor
        (moved to ``config["device"]``) and a ``"class"`` CPU tensor.
        """
        gt_parts = []
        pd_parts = []
        with torch.no_grad():
            for batch in batches:
                inputs = batch["image"].to(self.config["device"])
                predictions = self.model(inputs)
                # topk(k=1) indices along the last dim = predicted class id.
                pd_parts.append(predictions.topk(k=1)[1].view(-1).cpu().numpy())
                gt_parts.append(batch["class"].numpy().reshape(-1))
        # Concatenate once at the end; per-batch concatenation is quadratic.
        if not gt_parts:
            return np.empty(0, dtype=int), np.empty(0, dtype=int)
        return np.concatenate(gt_parts), np.concatenate(pd_parts)

    @staticmethod
    def _accuracy(pd_class, gt_class):
        """Return the fraction of matching entries, or ``nan`` when empty."""
        if pd_class.size == 0:
            return float("nan")
        return float(np.mean(pd_class == gt_class))

    def _save_predictions(self, pd_class):
        """Write predicted class ids, one integer per line, to the results dir."""
        # fmt="%d": np.savetxt defaults to '%.18e', which would write floats.
        np.savetxt(
            osp.join(self.res_dir, "predictions.txt"),
            pd_class.astype(int),
            fmt="%d",
        )