Diff of /tests/test_data.py [000000] .. [66326d]

Switch to unified view

a b/tests/test_data.py
1
"""
2
Test suite.
3
"""
4
5
import unittest
6
import itertools
7
import torch
8
9
from continual.src.utils import data_processing
10
11
BATCH_SIZES = (1, 10, 100)
12
SEQ_LENS = (4, 12, 48)
13
N_VARS = (2, 10, 100)
14
N_CLASSES = (2, 10)
15
N_LAYERS = (1, 2, 3, 4)
16
HIDDEN_SIZES = (32, 64, 128)
17
18
19
DEMOGRAPHICS = [
20
    "age",
21
    "gender",
22
    "ethnicity",
23
    "region",
24
    "time_year",
25
    "time_season",
26
    "time_month",
27
]
28
OUTCOMES = ["ARF", "shock", "mortality"]
29
DATASETS = ["MIMIC", "eICU"]
30
31
32
class TestDataLoadingMethods(unittest.TestCase):
33
    """
34
    Data loading tests.
35
    """
36
37
    def test_modalfeatvalfromseq(self):
38
        """
39
        Test that mode of correct dim is returned.
40
        """
41
        for n_samples in BATCH_SIZES:
42
            for seq_len in SEQ_LENS:
43
                for n_feats in N_VARS:
44
                    for i in range(n_feats):
45
                        sim_data = (
46
                            torch.randint(0, 1, (n_samples, seq_len, n_feats))
47
                            .clone()
48
                            .detach()
49
                            .numpy()
50
                        )
51
                        modes = data_processing.get_modes(sim_data, feat=i)
52
                        self.assertEqual(modes.shape, torch.Size([n_samples]))
53
54
55
# CL task split tests
56
class TestCLConstructionMethods(unittest.TestCase):
57
    """
58
    Test construction of Continual Learning task splits.
59
    """
60
61
    def ttest_taskidsnonoverlap(self):
62
        for dataset in DATASETS:
63
            for experiment in OUTCOMES:
64
                for demographic in DEMOGRAPHICS:
65
                    # JA: implement
66
                    tasks = data_processing.load_data(dataset, demographic, experiment)
67
                    for pair in itertools.combinations(tasks, repeat=2):
68
                        self.assertTrue(pair[0][:, 0].intersection(pair[0][:, 0]) == {})
69
70
    def ttest_tasktargets(self):
71
        for dataset in DATASETS:
72
            for experiment in OUTCOMES:
73
                for demographic in DEMOGRAPHICS:
74
                    # JA: implement
75
                    tasks = data_processing.load_data(dataset, demographic, experiment)
76
                    for task in tasks:
77
                        self.assertTrue(len(task[:, -1].unique()) == 2)
78
79
80
if __name__ == "__main__":
81
    unittest.main()