Diff of /trainers/trainers.py [000000] .. [fbbdf8]

Switch to side-by-side view

--- a
+++ b/trainers/trainers.py
@@ -0,0 +1,60 @@
+from dataloaders.dataset1d import EcgDataset1D
+from dataloaders.dataset2d import EcgDataset2D
+from models import models1d, models2d
+from trainers.base_trainer import BaseTrainer
+
+
+class Trainer2D(BaseTrainer):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def _init_net(self):
+        model = getattr(models2d, self.config["model"])(
+            num_classes=self.config["num_classes"],
+        )
+        model = model.to(self.config["device"])
+        return model
+
+    def _init_dataloaders(self):
+        train_loader = EcgDataset2D(
+            self.config["train_json"], self.config["mapping_json"],
+        ).get_dataloader(
+            batch_size=self.config["batch_size"],
+            num_workers=self.config["num_workers"],
+        )
+        val_loader = EcgDataset2D(
+            self.config["val_json"], self.config["mapping_json"],
+        ).get_dataloader(
+            batch_size=self.config["batch_size"],
+            num_workers=self.config["num_workers"],
+        )
+
+        return train_loader, val_loader
+
+
+class Trainer1D(BaseTrainer):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def _init_net(self):
+        model = getattr(models1d, self.config["model"])(
+            num_classes=self.config["num_classes"],
+        )
+        model = model.to(self.config["device"])
+        return model
+
+    def _init_dataloaders(self):
+        train_loader = EcgDataset1D(
+            self.config["train_json"], self.config["mapping_json"],
+        ).get_dataloader(
+            batch_size=self.config["batch_size"],
+            num_workers=self.config["num_workers"],
+        )
+        val_loader = EcgDataset1D(
+            self.config["val_json"], self.config["mapping_json"],
+        ).get_dataloader(
+            batch_size=self.config["batch_size"],
+            num_workers=self.config["num_workers"],
+        )
+
+        return train_loader, val_loader