Diff of /runners/runners.py [000000] .. [fbbdf8]

Switch to unified view

a b/runners/runners.py
1
from dataloaders.dataset1d import EcgDataset1D
2
from dataloaders.dataset2d import EcgDataset2D
3
from models import models1d, models2d
4
from runners.base_runner import BaseRunner
5
6
7
class Runner2D(BaseRunner):
8
    def __init__(self, config):
9
        super().__init__(config)
10
11
    def _init_net(self):
12
        model = getattr(models2d, self.config["model"])(
13
            num_classes=self.config["num_classes"],
14
        )
15
        model = model.to(self.config["device"])
16
        return model
17
18
    def _init_dataloader(self):
19
        inference_loader = EcgDataset2D(
20
            self.config["json"], self.config["mapping_json"],
21
        ).get_dataloader(
22
            batch_size=self.config["batch_size"],
23
            num_workers=self.config["num_workers"],
24
            shuffle=False,
25
        )
26
27
        return inference_loader
28
29
30
class Runner1D(BaseRunner):
31
    def __init__(self, config):
32
        super().__init__(config)
33
34
    def _init_net(self):
35
        model = getattr(models1d, self.config["model"])(
36
            num_classes=self.config["num_classes"],
37
        )
38
        model = model.to(self.config["device"])
39
        return model
40
41
    def _init_dataloader(self):
42
        inference_loader = EcgDataset1D(
43
            self.config["json"], self.config["mapping_json"],
44
        ).get_dataloader(
45
            batch_size=self.config["batch_size"],
46
            num_workers=self.config["num_workers"],
47
            shuffle=False,
48
        )
49
50
        return inference_loader