Diff of /tests.py [000000] .. [e6696a]

Switch to unified view

a b/tests.py
1
import unittest
2
from utils import Feats
3
from utils import SimulateRaw
4
from utils import PreProcess
5
from utils import CreateModel
6
from utils import TrainTestVal
7
from utils import LoadMuseData
8
from utils import FeatureEngineer
9
10
11
class ExampleTest(unittest.TestCase):
12
    """
13
    Our basic test class
14
    """
15
16
    def test_addition(self):
17
        """
18
        The actual test.
19
        Any method which starts with ``test_`` will considered as a test case.
20
        """
21
        res = 2 + 2
22
        self.assertEqual(res, 4)
23
24
    def test_feats(self):
25
        """
26
        Testing utils import.
27
        """
28
        feats = Feats()
29
        self.assertEqual(feats.num_classes, 2)
30
31
    def test_example_muse(self):
32
        """
33
        Testing example Muse code.
34
        """
35
36
        # Load Data
37
        raw = LoadMuseData(subs=[101, 102], nsesh=2, data_dir='visual/cueing')
38
39
        # Pre-Process EEG Data
40
        epochs = PreProcess(raw=raw, event_id={'LeftCue': 1, 'RightCue': 2})
41
42
        # Engineer Features for Model
43
        feats = FeatureEngineer(epochs=epochs)
44
45
        # Create Model
46
        model, _ = CreateModel(feats=feats)
47
48
        # Train with validation, then Test
49
        model, data = TrainTestVal(model=model,
50
                                   feats=feats,
51
                                   train_epochs=1,
52
                                   show_plots=False)
53
54
        self.assertLess(data['acc'], 1)
55
56
57
    def test_simulate_raw(self):
58
        """
59
        Testing simulated data pipeline.
60
        """
61
        # Simulate Data
62
        raw,event_id = SimulateRaw(amp1=50, amp2=60, freq=1.)
63
64
        # Pre-Process EEG Data
65
        epochs = PreProcess(raw,event_id)
66
67
        # Engineer Features for Model
68
        feats = FeatureEngineer(epochs)
69
70
        # Create Model
71
        model, _ = CreateModel(feats, units=[16,16])
72
73
        # Train with validation, then Test
74
        model, data = TrainTestVal(model,feats, 
75
                    train_epochs=1,show_plots=False)
76
77
        self.assertLess(data['acc'], 1)
78
    
79
    def test_frequencydomain_complex(self):
80
        """
81
        Testing simulated data pipeline.
82
        """
83
        # Simulate Data
84
        raw,event_id = SimulateRaw(amp1=50, amp2=60, freq=1.)
85
86
        # Pre-Process EEG Data
87
        epochs = PreProcess(raw,event_id)
88
89
        # Engineer Features for Model
90
        feats = FeatureEngineer(epochs,frequency_domain=True,
91
                                include_phase=True)
92
93
        # Create Model
94
        model, _ = CreateModel(feats, units=[16,16])
95
96
        # Train with validation, then Test
97
        model, data = TrainTestVal(model,feats, 
98
                    train_epochs=1,show_plots=False)
99
100
        self.assertLess(data['acc'], 1)
101
102
if __name__ == '__main__':
103
    unittest.main()