Diff of /pipelines/pipelines.py [000000] .. [fbbdf8]

Switch to unified view

a b/pipelines/pipelines.py
1
from dataloaders.dataset1d import EcgPipelineDataset1D
2
from models import models1d
3
from pipelines.base_pipeline import BasePipeline
4
5
6
class Pipeline1D(BasePipeline):
7
    def __init__(self, config):
8
        super().__init__(config)
9
10
    def _init_net(self):
11
        model = getattr(models1d, self.config["model"])(
12
            num_classes=self.config["num_classes"],
13
        )
14
        model = model.to(self.config["device"])
15
        return model
16
17
    def _init_dataloader(self):
18
        inference_loader = EcgPipelineDataset1D(self.config["ecg_data"]).get_dataloader(
19
            batch_size=self.config["batch_size"],
20
            num_workers=self.config["num_workers"],
21
            shuffle=False,
22
        )
23
24
        return inference_loader