Diff of /tests/test_unit.py [000000] .. [66326d]


"""
Unit tests for the model definitions and data-processing utilities.
"""

import math
import unittest

import torch

from continual.src.utils import models, data_processing
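
# Grids of input and model sizes swept over in the tests below.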
BATCH_SIZES = (1, 10, 100)
SEQ_LENS = (4, 12, 48)
N_VARS = (2, 10, 100)
N_CLASSES = (2, 10)
N_LAYERS = (1, 2, 3, 4)
HIDDEN_SIZES = (32, 64, 128)
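
# Dataset metadata: demographic/temporal features, prediction outcomes, and
# source ICU datasets.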
DEMOGRAPHICS = [
    "age",
    "gender",
    "ethnicity",
    "region",
    "time_year",
    "time_season",
    "time_month",
]
OUTCOMES = ["ARF", "shock", "mortality"]
DATASETS = ["MIMIC", "eICU"]
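

# Helper used by the model-capacity test below.
# Examples: magnitude(3) == 0, magnitude(950) == 2, magnitude(0.05) == -2.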
def magnitude(value):
    """
    Return the base-10 order of magnitude of a non-negative number.
    """
    if value < 0:
        raise ValueError("value must be non-negative")
    if value == 0:
        return 0
    return int(math.floor(math.log10(value)))


class TestModelMethods(unittest.TestCase):
    """
    Model definition tests.
    """

    def test_modeloutputshape(self):
        """
        Test that each model produces output of the correct shape for a variety
        of input sizes.
        """
        for batch_size in BATCH_SIZES:
            for seq_len in SEQ_LENS:
                for n_vars in N_VARS:
                    for n_classes in N_CLASSES:
                        for n_layers in N_LAYERS:
                            for hidden_size in HIDDEN_SIZES:
                                batch = torch.randn(batch_size, seq_len, n_vars)
                                simple_models = models.MODELS.values()
                                for model_cls in simple_models:
                                    model = model_cls(
                                        seq_len=seq_len,
                                        n_channels=n_vars,
                                        hidden_dim=hidden_size,
                                        output_size=n_classes,
                                        n_layers=n_layers,
                                    )
                                    # Use eval mode so batch norm does not try to
                                    # compute batch statistics over a batch of size 1.
                                    model.eval()
                                    output = model(batch)
                                    expected_shape = torch.Size([batch_size, n_classes])
                                    self.assertEqual(output.shape, expected_shape)

    @unittest.skip(
        "JA: Needs updating given the parameterisation of the model structure."
    )
    def test_modelcapacity(self):
        """
        Test that the different models have the same order of magnitude of
        trainable parameters.
        """
        for seq_len in SEQ_LENS:
            for n_vars in N_VARS:
                for n_classes in N_CLASSES:
                    simple_models = models.MODELS.values()
                    n_params = [
                        sum(
                            p.numel()
                            for p in m(
                                seq_len=seq_len,
                                n_channels=n_vars,
                                output_size=n_classes,
                            ).parameters()
                            if p.requires_grad
                        )
                        for m in simple_models
                    ]
                    param_magnitudes = [magnitude(p) for p in n_params]
                    # The RNN/LSTM models are up to an order of magnitude larger,
                    # so allow a difference of one.
                    self.assertTrue(max(param_magnitudes) - min(param_magnitudes) <= 1)

    # JA: Implement a test to check that params passed by config actually change
    # the model structure.


class TestDataLoadingMethods(unittest.TestCase):
    """
    Data loading tests.
    """

    def test_modalfeatvalfromseq(self):
        """
        Test that the per-sample mode is returned with the correct dimensions.
        """
        for n_samples in BATCH_SIZES:
            for seq_len in SEQ_LENS:
                for n_feats in N_VARS:
                    for i in range(n_feats):
                        # randint's upper bound is exclusive, so (0, 2) draws
                        # binary values; (0, 1) would yield all zeros.
                        sim_data = torch.randint(
                            0, 2, (n_samples, seq_len, n_feats)
                        ).numpy()
                        modes = data_processing.get_modes(sim_data, feat=i)
                        self.assertEqual(modes.shape, torch.Size([n_samples]))
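
# Running this module directly executes the full suite.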


if __name__ == "__main__":
    unittest.main()