[95f789]: / src / runner.py

Download this file

25 lines (22 with data), 761 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from typing import Mapping, Any
import torch.nn as nn
from catalyst.dl.runner import SupervisedRunner, SupervisedWandbRunner
from catalyst.dl.core import RunnerState
from catalyst.contrib.optimizers import Lookahead
class ModelRunner(SupervisedWandbRunner):
    """Supervised W&B runner preconfigured with this project's default keys.

    The class adds no behaviour of its own: it only fixes default values for
    the key arguments and passes everything on to ``SupervisedWandbRunner``.
    """

    def __init__(
        self,
        model: nn.Module = None,
        device=None,
        # Alternative default kept for reference (currently disabled):
        #   input_key = ("images", "meta")
        input_key: str = "images",
        output_key: str = "logits",
        input_target_key: str = "targets",
    ):
        """Forward every argument, unchanged, to the parent constructor.

        Args:
            model: Network handed to the parent runner; may be ``None``.
            device: Device specification handed to the parent runner; may be
                ``None``.
            input_key: Passed through as the parent's ``input_key``
                (presumably the batch entry fed to the model -- confirm
                against the catalyst runner documentation).
            output_key: Passed through as the parent's ``output_key``.
            input_target_key: Passed through as the parent's
                ``input_target_key``.
        """
        super().__init__(
            model=model,
            device=device,
            input_key=input_key,
            output_key=output_key,
            input_target_key=input_target_key,
        )