|
a |
|
b/runners/base_runner.py |
|
|
1 |
import os |
|
|
2 |
import os.path as osp |
|
|
3 |
from datetime import datetime |
|
|
4 |
|
|
|
5 |
import numpy as np |
|
|
6 |
import torch |
|
|
7 |
from tqdm import tqdm |
|
|
8 |
|
|
|
9 |
from utils.network_utils import load_checkpoint |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
class BaseRunner: |
|
|
13 |
def __init__(self, config): |
|
|
14 |
self.config = config |
|
|
15 |
self.exp_name = self.config.get("exp_name", None) |
|
|
16 |
if self.exp_name is None: |
|
|
17 |
self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") |
|
|
18 |
|
|
|
19 |
self.res_dir = osp.join(self.config["exp_dir"], self.exp_name, "results") |
|
|
20 |
os.makedirs(self.res_dir, exist_ok=True) |
|
|
21 |
|
|
|
22 |
self.model = self._init_net() |
|
|
23 |
|
|
|
24 |
self.inference_loader = self._init_dataloader() |
|
|
25 |
|
|
|
26 |
pretrained_path = self.config.get("model_path", False) |
|
|
27 |
if pretrained_path: |
|
|
28 |
load_checkpoint(pretrained_path, self.model) |
|
|
29 |
else: |
|
|
30 |
raise Exception( |
|
|
31 |
"model_path doesnt't exist in config. Please specify checkpoint path", |
|
|
32 |
) |
|
|
33 |
|
|
|
34 |
def _init_net(self): |
|
|
35 |
raise NotImplemented |
|
|
36 |
|
|
|
37 |
def _init_dataloader(self): |
|
|
38 |
raise NotImplemented |
|
|
39 |
|
|
|
40 |
def inference(self): |
|
|
41 |
self.model.eval() |
|
|
42 |
|
|
|
43 |
gt_class = np.empty(0) |
|
|
44 |
pd_class = np.empty(0) |
|
|
45 |
|
|
|
46 |
with torch.no_grad(): |
|
|
47 |
for i, batch in tqdm(enumerate(self.inference_loader)): |
|
|
48 |
inputs = batch["image"].to(self.config["device"]) |
|
|
49 |
|
|
|
50 |
predictions = self.model(inputs) |
|
|
51 |
|
|
|
52 |
classes = predictions.topk(k=1)[1].view(-1).cpu().numpy() |
|
|
53 |
|
|
|
54 |
gt_class = np.concatenate((gt_class, batch["class"].numpy())) |
|
|
55 |
pd_class = np.concatenate((pd_class, classes)) |
|
|
56 |
|
|
|
57 |
class_accuracy = sum(pd_class == gt_class) / pd_class.shape[0] |
|
|
58 |
print("Validation CLASS accuracy - {:4f}".format(class_accuracy)) |
|
|
59 |
|
|
|
60 |
pd_class = pd_class.astype(int) |
|
|
61 |
np.savetxt(osp.join(self.res_dir, "predictions.txt"), pd_class) |