[fbbdf8]: / runners / runners.py

Download this file

51 lines (40 with data), 1.5 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from dataloaders.dataset1d import EcgDataset1D
from dataloaders.dataset2d import EcgDataset2D
from models import models1d, models2d
from runners.base_runner import BaseRunner
class Runner2D(BaseRunner):
    """Runner wired to the 2D model factory and the 2D ECG dataset.

    All settings are read from ``self.config`` using the keys ``model``,
    ``num_classes``, ``device``, ``json``, ``mapping_json``, ``batch_size``
    and ``num_workers``.
    NOTE(review): ``self.config`` is presumably stored by
    ``BaseRunner.__init__`` -- confirm against the base class.
    """

    def __init__(self, config):
        """Forward ``config`` unchanged to the base runner."""
        super().__init__(config)

    def _init_net(self):
        """Build the model named by ``config["model"]`` and move it to the device.

        The model class is looked up by name on the ``models2d`` module and
        constructed with ``num_classes`` as its only argument.
        """
        settings = self.config
        # Resolve the constructor by name; an unknown name raises AttributeError.
        net_factory = getattr(models2d, settings["model"])
        net = net_factory(num_classes=settings["num_classes"])
        return net.to(settings["device"])

    def _init_dataloader(self):
        """Return a non-shuffled dataloader over an ``EcgDataset2D``.

        The dataset is built from ``config["json"]`` and
        ``config["mapping_json"]``; batching and worker count come from
        ``config["batch_size"]`` and ``config["num_workers"]``.
        """
        settings = self.config
        dataset = EcgDataset2D(settings["json"], settings["mapping_json"])
        return dataset.get_dataloader(
            batch_size=settings["batch_size"],
            num_workers=settings["num_workers"],
            shuffle=False,
        )
class Runner1D(BaseRunner):
    """Runner wired to the 1D model factory and the 1D ECG dataset.

    All settings are read from ``self.config`` using the keys ``model``,
    ``num_classes``, ``device``, ``json``, ``mapping_json``, ``batch_size``
    and ``num_workers``.
    NOTE(review): ``self.config`` is presumably stored by
    ``BaseRunner.__init__`` -- confirm against the base class.
    """

    def __init__(self, config):
        """Forward ``config`` unchanged to the base runner."""
        super().__init__(config)

    def _init_net(self):
        """Build the model named by ``config["model"]`` and move it to the device.

        The model class is looked up by name on the ``models1d`` module and
        constructed with ``num_classes`` as its only argument.
        """
        settings = self.config
        # Resolve the constructor by name; an unknown name raises AttributeError.
        net_factory = getattr(models1d, settings["model"])
        net = net_factory(num_classes=settings["num_classes"])
        return net.to(settings["device"])

    def _init_dataloader(self):
        """Return a non-shuffled dataloader over an ``EcgDataset1D``.

        The dataset is built from ``config["json"]`` and
        ``config["mapping_json"]``; batching and worker count come from
        ``config["batch_size"]`` and ``config["num_workers"]``.
        """
        settings = self.config
        dataset = EcgDataset1D(settings["json"], settings["mapping_json"])
        return dataset.get_dataloader(
            batch_size=settings["batch_size"],
            num_workers=settings["num_workers"],
            shuffle=False,
        )