Diff of /runners/base_runner.py [000000] .. [fbbdf8]

Switch to unified view

a b/runners/base_runner.py
1
import os
2
import os.path as osp
3
from datetime import datetime
4
5
import numpy as np
6
import torch
7
from tqdm import tqdm
8
9
from utils.network_utils import load_checkpoint
10
11
12
class BaseRunner:
13
    def __init__(self, config):
14
        self.config = config
15
        self.exp_name = self.config.get("exp_name", None)
16
        if self.exp_name is None:
17
            self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
18
19
        self.res_dir = osp.join(self.config["exp_dir"], self.exp_name, "results")
20
        os.makedirs(self.res_dir, exist_ok=True)
21
22
        self.model = self._init_net()
23
24
        self.inference_loader = self._init_dataloader()
25
26
        pretrained_path = self.config.get("model_path", False)
27
        if pretrained_path:
28
            load_checkpoint(pretrained_path, self.model)
29
        else:
30
            raise Exception(
31
                "model_path doesnt't exist in config. Please specify checkpoint path",
32
            )
33
34
    def _init_net(self):
35
        raise NotImplemented
36
37
    def _init_dataloader(self):
38
        raise NotImplemented
39
40
    def inference(self):
41
        self.model.eval()
42
43
        gt_class = np.empty(0)
44
        pd_class = np.empty(0)
45
46
        with torch.no_grad():
47
            for i, batch in tqdm(enumerate(self.inference_loader)):
48
                inputs = batch["image"].to(self.config["device"])
49
50
                predictions = self.model(inputs)
51
52
                classes = predictions.topk(k=1)[1].view(-1).cpu().numpy()
53
54
                gt_class = np.concatenate((gt_class, batch["class"].numpy()))
55
                pd_class = np.concatenate((pd_class, classes))
56
57
        class_accuracy = sum(pd_class == gt_class) / pd_class.shape[0]
58
        print("Validation CLASS accuracy - {:4f}".format(class_accuracy))
59
60
        pd_class = pd_class.astype(int)
61
        np.savetxt(osp.join(self.res_dir, "predictions.txt"), pd_class)