Diff of /ecg_datamodule.py [000000] .. [134fd7]

Switch to unified view

a b/ecg_datamodule.py
1
import os
2
from typing import Optional, Sequence
3
from warnings import warn
4
5
import torch
6
from pytorch_lightning import LightningDataModule
7
from torch.utils.data import DataLoader, random_split
8
9
from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
10
11
12
class ECGDataModule(LightningDataModule):
    """LightningDataModule wrapping SimCLR-style ECG datasets.

    Builds self-supervised train/validation loaders from
    ``config['dataset']`` / ``config['batch_size']`` and supervised
    (linear-evaluation) loaders from ``config['eval_dataset']`` /
    ``config['eval_batch_size']`` via ``SimCLRDataSetWrapper``.
    """

    name = 'ecg_dataset'
    extra_args = {}

    def __init__(
            self,
            config,
            transformations_str,
            t_params,
            data_dir: Optional[str] = None,
            val_split: int = 5000,
            num_workers: int = 16,
            batch_size: int = 32,
            seed: int = 42,
            *args,
            **kwargs,
    ):
        """
        Args:
            config: mapping with at least 'batch_size', 'dataset',
                'eval_batch_size' and 'eval_dataset' entries
                (forwarded to SimCLRDataSetWrapper).
            transformations_str: specification of the augmentations to apply.
            t_params: parameters for those transformations.
            data_dir: data root directory; defaults to the current
                working directory.
            val_split: unused here; kept for interface compatibility.
            num_workers: number of DataLoader workers.
            batch_size: training batch size.
            seed: random seed.
        """
        super().__init__(*args, **kwargs)

        # (channels, timesteps) of a single ECG sample:
        # 12 leads x 250 samples.
        self.dims = (12, 250)
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.seed = seed
        self.data_dir = data_dir if data_dir is not None else os.getcwd()

        self.config = config
        self.transformations_str = transformations_str
        self.t_params = t_params
        self.set_params()

    def set_params(self):
        """Instantiate the training dataset once to record its size,
        resolved transformations, and default loaders.

        Fix: the loaders are now stored on ``self`` instead of being
        discarded, so ``test_dataloader`` no longer raises
        AttributeError on the previously-unset ``self.valid_loader``.
        """
        dataset = SimCLRDataSetWrapper(
            self.config['batch_size'], **self.config['dataset'],
            transformations=self.transformations_str, t_params=self.t_params)
        self.train_loader, self.valid_loader = dataset.get_data_loaders()
        self.num_samples = dataset.train_ds_size
        self.transformations = dataset.transformations

    @property
    def num_classes(self):
        """Number of target classes.

        Return:
            5
        """
        return 5

    def prepare_data(self):
        # Download/preparation is handled inside SimCLRDataSetWrapper;
        # nothing to do here.
        pass

    def train_dataloader(self):
        """Return a fresh self-supervised training DataLoader."""
        dataset = SimCLRDataSetWrapper(
            self.config['batch_size'], **self.config['dataset'],
            transformations=self.transformations_str, t_params=self.t_params)
        train_loader, _ = dataset.get_data_loaders()
        return train_loader

    def val_dataloader(self):
        """Return the validation loaders.

        Returns:
            A list of three DataLoaders:
            [self-supervised validation loader,
             supervised (linear-evaluation) validation loader,
             supervised (linear-evaluation) test loader].
        """
        dataset = SimCLRDataSetWrapper(
            self.config['eval_batch_size'], **self.config['eval_dataset'],
            transformations=self.transformations_str, t_params=self.t_params)
        _, valid_loader_self = dataset.get_data_loaders()
        dataset = SimCLRDataSetWrapper(
            self.config['eval_batch_size'], **self.config['eval_dataset'],
            transformations=self.transformations_str, t_params=self.t_params,
            mode="linear_evaluation")
        valid_loader_sup, test_loader_sup = dataset.get_data_loaders()
        return [valid_loader_self, valid_loader_sup, test_loader_sup]

    def test_dataloader(self):
        """Return the validation loader captured by ``set_params``."""
        return self.valid_loader

    def default_transforms(self):
        # No built-in default transforms; augmentations are driven entirely
        # by transformations_str / t_params.
        pass