Switch to unified view

a b/tests/unit/openbci_test.py
1
import json
2
import time
3
import unittest
4
from mock import patch
5
6
from cloudbrain.modules.sources.openbci import OpenBCISource
7
from cloudbrain.publishers.rabbitmq import PikaPublisher
8
from cloudbrain.connectors.openbci import OpenBCIConnector, OpenBCISample
9
10
11
12
class MockSerial(object):
13
    """
14
    Patch class serial.Serial
15
    """
16
17
18
    def __init__(self, port, baud):
19
        pass
20
21
22
    def write(self, text):
23
        pass
24
25
26
    def read(self):
27
        return 'mock'
28
29
30
    def close(self):
31
        pass
32
33
34
    def inWaiting(self):
35
        return False
36
37
38
class MockHTTPResponse(object):
39
40
    status_code = 200
41
42
    def json(self):
43
        return {'vhost': 'mock_vhost'}
44
45
46
def mock_requests_get(vhost_info_url, verify):
47
    return MockHTTPResponse()
48
49
50
def mock_start(self, callback_functions):
51
    """
52
    Patch method OpenBCIConnector.start
53
54
    :param callback_functions: (dict)
55
      E.g: {metric_0: callback_0, ...,  metric_N: callback_N}
56
      where 'metric_X' is a string and 'callback_X' is a function that takes an
57
      Array as argument.
58
    """
59
60
    packet_id = 0
61
    channel_data = [i for i in range(8)]
62
    aux_data = []
63
    timestamp = int(time.time() * 1000000) # in microseconds
64
    sample = OpenBCISample(packet_id, channel_data, aux_data, timestamp)
65
    for (metric, callback_function) in callback_functions.items():
66
        callback_function(sample)
67
68
69
70
class MockChannel(object):
71
    def __init__(self):
72
        self.published_count = 0
73
        self.message_published = None
74
75
76
    def exchange_declare(self, exchange, exchange_type):
77
        pass
78
79
80
    def basic_publish(self, exchange, routing_key, body, properties):
81
        self.published_count += 1
82
        self.message_published = body
83
        print("Message count: %s" % self.published_count)
84
        print("Message body: %s" % body)
85
86
87
88
class MockBlockingConnection(object):
89
    """
90
    Patch pika's BlockingConnection
91
    """
92
93
94
    def __init__(self, params):
95
        pass
96
97
98
    def close(self):
99
        pass
100
101
102
    def channel(self):
103
        return MockChannel()
104
105
106
107
class OpenBCITest(unittest.TestCase):
108
    def setUp(self):
109
        self.device = 'openbci'
110
        self.user = "mock"
111
112
        # OpenBCI config
113
        self.port = None
114
        self.baud = 0
115
        self.filter_data = False
116
117
        # metric info
118
        self.metric_name = 'eeg'
119
        self.num_channels = 8
120
        self.buffer_size = 2
121
122
        # rmq info
123
        self.rabbitmq_user = 'mock_user'
124
        self.rabbitmq_pwd = 'mock_pwd'
125
126
        self.base_routing_key = '%s:%s' % (self.user, self.device)
127
128
        self.message = {'timestamp': 100}
129
        for i in range(self.num_channels):
130
            self.message['channel_%s' % i] = i
131
132
133
    def validate_start_method(self, sample):
134
        self.assertEqual(sample.channel_data, [i for i in range(self.num_channels)])
135
        print("OpenBCI started: %s" % sample.channel_data)
136
137
138
    @patch('serial.Serial', MockSerial)
139
    @patch('cloudbrain.connectors.openbci.OpenBCIConnector.start',
140
           mock_start)
141
    def test_OpenBCIConnector(self):
142
        board = OpenBCIConnector()
143
        callbacks = {self.metric_name: self.validate_start_method}
144
        board.start(callbacks)
145
146
147
    @patch('requests.get', mock_requests_get)
148
    @patch('pika.BlockingConnection', MockBlockingConnection)
149
    def test_PikaPublisher(self):
150
151
        options = {"rabbitmq_user": self.rabbitmq_user,
152
                   "rabbitmq_pwd": self.rabbitmq_pwd}
153
154
        publisher = PikaPublisher(self.base_routing_key, **options)
155
        publisher.connect()
156
        publisher.register(self.metric_name, self.num_channels,
157
                           self.buffer_size)
158
159
        for metric_name in publisher.metrics_to_num_channels().keys():
160
            routing_key = "%s:%s" % (self.base_routing_key, metric_name)
161
            self.assertEqual(publisher.channels[routing_key].published_count, 0)
162
            self.assertEqual(publisher.channels[routing_key].message_published,
163
                             None)
164
165
            publisher.publish(metric_name, self.message)
166
            self.assertEqual(publisher.channels[routing_key].published_count, 0)
167
168
            publisher.publish(metric_name, self.message)
169
            self.assertEqual(publisher.channels[routing_key].published_count, 1)
170
171
            expected_message = [{
172
                "channel_5": 5, "channel_4": 4, "channel_7": 7,
173
                "channel_6": 6, "channel_1": 1, "channel_0": 0,
174
                "channel_3": 3, "channel_2": 2, "timestamp": 100
175
            }, {
176
                "channel_5": 5, "channel_4": 4, "channel_7": 7,
177
                "channel_6": 6, "channel_1": 1, "channel_0": 0,
178
                "channel_3": 3, "channel_2": 2, "timestamp": 100
179
            }]
180
181
            published_message = json.loads(
182
                publisher.channels[routing_key].message_published)
183
            self.assertEqual(published_message, expected_message)
184
185
186
    @patch('requests.get', mock_requests_get)
187
    @patch('serial.Serial', MockSerial)
188
    @patch('pika.BlockingConnection', MockBlockingConnection)
189
    @patch('cloudbrain.connectors.openbci.OpenBCIConnector.start',
190
           mock_start)
191
    def test_OpenBCISource(self):
192
        options = {"rabbitmq_user": self.rabbitmq_user,
193
                   "rabbitmq_pwd": self.rabbitmq_pwd}
194
195
        publisher = PikaPublisher(self.base_routing_key, **options)
196
        publisher.connect()
197
        publisher.register(self.metric_name, self.num_channels,
198
                           self.buffer_size)
199
200
        publishers = [publisher]
201
        subscribers = []
202
        source = OpenBCISource(subscribers=subscribers,
203
                               publishers=publishers,
204
                               port=self.port,
205
                               baud=self.baud,
206
                               filter_data=self.filter_data)
207
        source.start()
208
209
        pub = source.publishers[0]
210
211
        routing_key = "%s:%s" % (self.base_routing_key, self.metric_name)
212
        self.assertEqual(pub.channels[routing_key].published_count, 0)
213
        self.assertEqual(pub.channels[routing_key].message_published,
214
                         None, "No messages should have been sent yet.")