|
a |
|
b/pipelines/pipelines.py |
|
|
1 |
from dataloaders.dataset1d import EcgPipelineDataset1D |
|
|
2 |
from models import models1d |
|
|
3 |
from pipelines.base_pipeline import BasePipeline |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
class Pipeline1D(BasePipeline): |
|
|
7 |
def __init__(self, config): |
|
|
8 |
super().__init__(config) |
|
|
9 |
|
|
|
10 |
def _init_net(self): |
|
|
11 |
model = getattr(models1d, self.config["model"])( |
|
|
12 |
num_classes=self.config["num_classes"], |
|
|
13 |
) |
|
|
14 |
model = model.to(self.config["device"]) |
|
|
15 |
return model |
|
|
16 |
|
|
|
17 |
def _init_dataloader(self): |
|
|
18 |
inference_loader = EcgPipelineDataset1D(self.config["ecg_data"]).get_dataloader( |
|
|
19 |
batch_size=self.config["batch_size"], |
|
|
20 |
num_workers=self.config["num_workers"], |
|
|
21 |
shuffle=False, |
|
|
22 |
) |
|
|
23 |
|
|
|
24 |
return inference_loader |