Diff of /runners/base_runner.py [000000] .. [fbbdf8]

Switch to side-by-side view

--- a
+++ b/runners/base_runner.py
@@ -0,0 +1,61 @@
+import os
+import os.path as osp
+from datetime import datetime
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from utils.network_utils import load_checkpoint
+
+
+class BaseRunner:
+    def __init__(self, config):
+        self.config = config
+        self.exp_name = self.config.get("exp_name", None)
+        if self.exp_name is None:
+            self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+
+        self.res_dir = osp.join(self.config["exp_dir"], self.exp_name, "results")
+        os.makedirs(self.res_dir, exist_ok=True)
+
+        self.model = self._init_net()
+
+        self.inference_loader = self._init_dataloader()
+
+        pretrained_path = self.config.get("model_path", False)
+        if pretrained_path:
+            load_checkpoint(pretrained_path, self.model)
+        else:
+            raise Exception(
+                "model_path doesnt't exist in config. Please specify checkpoint path",
+            )
+
+    def _init_net(self):
+        raise NotImplemented
+
+    def _init_dataloader(self):
+        raise NotImplemented
+
+    def inference(self):
+        self.model.eval()
+
+        gt_class = np.empty(0)
+        pd_class = np.empty(0)
+
+        with torch.no_grad():
+            for i, batch in tqdm(enumerate(self.inference_loader)):
+                inputs = batch["image"].to(self.config["device"])
+
+                predictions = self.model(inputs)
+
+                classes = predictions.topk(k=1)[1].view(-1).cpu().numpy()
+
+                gt_class = np.concatenate((gt_class, batch["class"].numpy()))
+                pd_class = np.concatenate((pd_class, classes))
+
+        class_accuracy = sum(pd_class == gt_class) / pd_class.shape[0]
+        print("Validation CLASS accuracy - {:4f}".format(class_accuracy))
+
+        pd_class = pd_class.astype(int)
+        np.savetxt(osp.join(self.res_dir, "predictions.txt"), pd_class)