Diff of /src/runner.py [000000] .. [95f789]

Switch to unified view

a b/src/runner.py
1
from typing import Mapping, Any
2
import torch.nn as nn
3
from catalyst.dl.runner import SupervisedRunner, SupervisedWandbRunner
4
from catalyst.dl.core import RunnerState
5
from catalyst.contrib.optimizers import Lookahead
6
7
8
class ModelRunner(SupervisedWandbRunner):
9
    def __init__(
10
            self,
11
            model: nn.Module = None,
12
            device=None,
13
            # input_key: str = ("images", "meta"),
14
            input_key: str = "images",
15
            output_key: str = "logits",
16
            input_target_key: str = "targets",
17
    ):
18
        super(ModelRunner, self).__init__(
19
            model=model,
20
            device=device,
21
            input_key=input_key,
22
            output_key=output_key,
23
            input_target_key=input_target_key
24
        )