Diff of /src/runner.py [000000] .. [95f789]

Switch to side-by-side view

--- a
+++ b/src/runner.py
@@ -0,0 +1,24 @@
+from typing import Mapping, Any
+import torch.nn as nn
+from catalyst.dl.runner import SupervisedRunner, SupervisedWandbRunner
+from catalyst.dl.core import RunnerState
+from catalyst.contrib.optimizers import Lookahead
+
+
+class ModelRunner(SupervisedWandbRunner):
+    def __init__(
+            self,
+            model: nn.Module = None,
+            device=None,
+            # input_key: str = ("images", "meta"),
+            input_key: str = "images",
+            output_key: str = "logits",
+            input_target_key: str = "targets",
+    ):
+        super(ModelRunner, self).__init__(
+            model=model,
+            device=device,
+            input_key=input_key,
+            output_key=output_key,
+            input_target_key=input_target_key
+        )