[c128d9]: / pipelines / pipelines.py

Download this file

25 lines (19 with data), 739 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from dataloaders.dataset1d import EcgPipelineDataset1D
from models import models1d
from pipelines.base_pipeline import BasePipeline
class Pipeline1D(BasePipeline):
def __init__(self, config):
super().__init__(config)
def _init_net(self):
model = getattr(models1d, self.config["model"])(
num_classes=self.config["num_classes"],
)
model = model.to(self.config["device"])
return model
def _init_dataloader(self):
inference_loader = EcgPipelineDataset1D(self.config["ecg_data"]).get_dataloader(
batch_size=self.config["batch_size"],
num_workers=self.config["num_workers"],
shuffle=False,
)
return inference_loader