a b/trainers/trainers.py
1
from dataloaders.dataset1d import EcgDataset1D
2
from dataloaders.dataset2d import EcgDataset2D
3
from models import models1d, models2d
4
from trainers.base_trainer import BaseTrainer
5
6
7
class Trainer2D(BaseTrainer):
    """Trainer that plugs a 2D model and 2D ECG dataloaders into BaseTrainer.

    The constructor is inherited from ``BaseTrainer`` (the previous override
    only forwarded ``config`` to it). This class supplies the network and the
    train/validation dataloaders.

    Config keys read here: ``model``, ``num_classes``, ``device``,
    ``train_json``, ``val_json``, ``mapping_json``, ``batch_size`` and
    ``num_workers``.
    """

    def _init_net(self):
        """Instantiate the configured model from ``models2d`` on the device.

        ``config["model"]`` is looked up by name on the ``models2d`` module and
        called with ``num_classes=config["num_classes"]``.

        Returns:
            The model returned by ``.to(config["device"])``.

        Raises:
            AttributeError: if ``models2d`` has no attribute named
                ``config["model"]``.
        """
        model_factory = getattr(models2d, self.config["model"])
        model = model_factory(num_classes=self.config["num_classes"])
        return model.to(self.config["device"])

    def _init_dataloaders(self):
        """Build the training and validation dataloaders.

        Each split is an ``EcgDataset2D`` built from that split's JSON path
        (``config["train_json"]`` / ``config["val_json"]``) plus the shared
        ``config["mapping_json"]``. It is then wrapped by
        ``get_dataloader`` with the configured ``batch_size`` and
        ``num_workers``.

        Returns:
            tuple: ``(train_loader, val_loader)``.
        """

        def build_loader(json_key):
            # Both splits share the mapping file and loader settings; only the
            # split-specific JSON path differs.
            dataset = EcgDataset2D(
                self.config[json_key], self.config["mapping_json"],
            )
            # NOTE(review): validation gets exactly the same get_dataloader
            # arguments as training, so any default there (e.g. shuffling)
            # applies to both splits -- confirm against EcgDataset2D.
            return dataset.get_dataloader(
                batch_size=self.config["batch_size"],
                num_workers=self.config["num_workers"],
            )

        train_loader = build_loader("train_json")
        val_loader = build_loader("val_json")
        return train_loader, val_loader
33
34
35
class Trainer1D(BaseTrainer):
    """Trainer that plugs a 1D model and 1D ECG dataloaders into BaseTrainer.

    The constructor is inherited from ``BaseTrainer`` (the previous override
    only forwarded ``config`` to it). This class supplies the network and the
    train/validation dataloaders.

    Config keys read here: ``model``, ``num_classes``, ``device``,
    ``train_json``, ``val_json``, ``mapping_json``, ``batch_size`` and
    ``num_workers``.
    """

    def _init_net(self):
        """Instantiate the configured model from ``models1d`` on the device.

        ``config["model"]`` is looked up by name on the ``models1d`` module and
        called with ``num_classes=config["num_classes"]``.

        Returns:
            The model returned by ``.to(config["device"])``.

        Raises:
            AttributeError: if ``models1d`` has no attribute named
                ``config["model"]``.
        """
        model_factory = getattr(models1d, self.config["model"])
        model = model_factory(num_classes=self.config["num_classes"])
        return model.to(self.config["device"])

    def _init_dataloaders(self):
        """Build the training and validation dataloaders.

        Each split is an ``EcgDataset1D`` built from that split's JSON path
        (``config["train_json"]`` / ``config["val_json"]``) plus the shared
        ``config["mapping_json"]``. It is then wrapped by
        ``get_dataloader`` with the configured ``batch_size`` and
        ``num_workers``.

        Returns:
            tuple: ``(train_loader, val_loader)``.
        """

        def build_loader(json_key):
            # Both splits share the mapping file and loader settings; only the
            # split-specific JSON path differs.
            dataset = EcgDataset1D(
                self.config[json_key], self.config["mapping_json"],
            )
            # NOTE(review): validation gets exactly the same get_dataloader
            # arguments as training, so any default there (e.g. shuffling)
            # applies to both splits -- confirm against EcgDataset1D.
            return dataset.get_dataloader(
                batch_size=self.config["batch_size"],
                num_workers=self.config["num_workers"],
            )

        train_loader = build_loader("train_json")
        val_loader = build_loader("val_json")
        return train_loader, val_loader