|
a |
|
b/tests/unit/openbci_test.py |
|
|
1 |
import json |
|
|
2 |
import time |
|
|
3 |
import unittest |
|
|
4 |
from mock import patch |
|
|
5 |
|
|
|
6 |
from cloudbrain.modules.sources.openbci import OpenBCISource |
|
|
7 |
from cloudbrain.publishers.rabbitmq import PikaPublisher |
|
|
8 |
from cloudbrain.connectors.openbci import OpenBCIConnector, OpenBCISample |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
|
|
|
12 |
class MockSerial(object): |
|
|
13 |
""" |
|
|
14 |
Patch class serial.Serial |
|
|
15 |
""" |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
def __init__(self, port, baud): |
|
|
19 |
pass |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
def write(self, text): |
|
|
23 |
pass |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def read(self): |
|
|
27 |
return 'mock' |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def close(self): |
|
|
31 |
pass |
|
|
32 |
|
|
|
33 |
|
|
|
34 |
def inWaiting(self): |
|
|
35 |
return False |
|
|
36 |
|
|
|
37 |
|
|
|
38 |
class MockHTTPResponse(object): |
|
|
39 |
|
|
|
40 |
status_code = 200 |
|
|
41 |
|
|
|
42 |
def json(self): |
|
|
43 |
return {'vhost': 'mock_vhost'} |
|
|
44 |
|
|
|
45 |
|
|
|
46 |
def mock_requests_get(vhost_info_url, verify): |
|
|
47 |
return MockHTTPResponse() |
|
|
48 |
|
|
|
49 |
|
|
|
50 |
def mock_start(self, callback_functions): |
|
|
51 |
""" |
|
|
52 |
Patch method OpenBCIConnector.start |
|
|
53 |
|
|
|
54 |
:param callback_functions: (dict) |
|
|
55 |
E.g: {metric_0: callback_0, ..., metric_N: callback_N} |
|
|
56 |
where 'metric_X' is a string and 'callback_X' is a function that takes an |
|
|
57 |
Array as argument. |
|
|
58 |
""" |
|
|
59 |
|
|
|
60 |
packet_id = 0 |
|
|
61 |
channel_data = [i for i in range(8)] |
|
|
62 |
aux_data = [] |
|
|
63 |
timestamp = int(time.time() * 1000000) # in microseconds |
|
|
64 |
sample = OpenBCISample(packet_id, channel_data, aux_data, timestamp) |
|
|
65 |
for (metric, callback_function) in callback_functions.items(): |
|
|
66 |
callback_function(sample) |
|
|
67 |
|
|
|
68 |
|
|
|
69 |
|
|
|
70 |
class MockChannel(object): |
|
|
71 |
def __init__(self): |
|
|
72 |
self.published_count = 0 |
|
|
73 |
self.message_published = None |
|
|
74 |
|
|
|
75 |
|
|
|
76 |
def exchange_declare(self, exchange, exchange_type): |
|
|
77 |
pass |
|
|
78 |
|
|
|
79 |
|
|
|
80 |
def basic_publish(self, exchange, routing_key, body, properties): |
|
|
81 |
self.published_count += 1 |
|
|
82 |
self.message_published = body |
|
|
83 |
print("Message count: %s" % self.published_count) |
|
|
84 |
print("Message body: %s" % body) |
|
|
85 |
|
|
|
86 |
|
|
|
87 |
|
|
|
88 |
class MockBlockingConnection(object): |
|
|
89 |
""" |
|
|
90 |
Patch pika's BlockingConnection |
|
|
91 |
""" |
|
|
92 |
|
|
|
93 |
|
|
|
94 |
def __init__(self, params): |
|
|
95 |
pass |
|
|
96 |
|
|
|
97 |
|
|
|
98 |
def close(self): |
|
|
99 |
pass |
|
|
100 |
|
|
|
101 |
|
|
|
102 |
def channel(self): |
|
|
103 |
return MockChannel() |
|
|
104 |
|
|
|
105 |
|
|
|
106 |
|
|
|
107 |
class OpenBCITest(unittest.TestCase): |
|
|
108 |
def setUp(self): |
|
|
109 |
self.device = 'openbci' |
|
|
110 |
self.user = "mock" |
|
|
111 |
|
|
|
112 |
# OpenBCI config |
|
|
113 |
self.port = None |
|
|
114 |
self.baud = 0 |
|
|
115 |
self.filter_data = False |
|
|
116 |
|
|
|
117 |
# metric info |
|
|
118 |
self.metric_name = 'eeg' |
|
|
119 |
self.num_channels = 8 |
|
|
120 |
self.buffer_size = 2 |
|
|
121 |
|
|
|
122 |
# rmq info |
|
|
123 |
self.rabbitmq_user = 'mock_user' |
|
|
124 |
self.rabbitmq_pwd = 'mock_pwd' |
|
|
125 |
|
|
|
126 |
self.base_routing_key = '%s:%s' % (self.user, self.device) |
|
|
127 |
|
|
|
128 |
self.message = {'timestamp': 100} |
|
|
129 |
for i in range(self.num_channels): |
|
|
130 |
self.message['channel_%s' % i] = i |
|
|
131 |
|
|
|
132 |
|
|
|
133 |
def validate_start_method(self, sample): |
|
|
134 |
self.assertEqual(sample.channel_data, [i for i in range(self.num_channels)]) |
|
|
135 |
print("OpenBCI started: %s" % sample.channel_data) |
|
|
136 |
|
|
|
137 |
|
|
|
138 |
@patch('serial.Serial', MockSerial) |
|
|
139 |
@patch('cloudbrain.connectors.openbci.OpenBCIConnector.start', |
|
|
140 |
mock_start) |
|
|
141 |
def test_OpenBCIConnector(self): |
|
|
142 |
board = OpenBCIConnector() |
|
|
143 |
callbacks = {self.metric_name: self.validate_start_method} |
|
|
144 |
board.start(callbacks) |
|
|
145 |
|
|
|
146 |
|
|
|
147 |
@patch('requests.get', mock_requests_get) |
|
|
148 |
@patch('pika.BlockingConnection', MockBlockingConnection) |
|
|
149 |
def test_PikaPublisher(self): |
|
|
150 |
|
|
|
151 |
options = {"rabbitmq_user": self.rabbitmq_user, |
|
|
152 |
"rabbitmq_pwd": self.rabbitmq_pwd} |
|
|
153 |
|
|
|
154 |
publisher = PikaPublisher(self.base_routing_key, **options) |
|
|
155 |
publisher.connect() |
|
|
156 |
publisher.register(self.metric_name, self.num_channels, |
|
|
157 |
self.buffer_size) |
|
|
158 |
|
|
|
159 |
for metric_name in publisher.metrics_to_num_channels().keys(): |
|
|
160 |
routing_key = "%s:%s" % (self.base_routing_key, metric_name) |
|
|
161 |
self.assertEqual(publisher.channels[routing_key].published_count, 0) |
|
|
162 |
self.assertEqual(publisher.channels[routing_key].message_published, |
|
|
163 |
None) |
|
|
164 |
|
|
|
165 |
publisher.publish(metric_name, self.message) |
|
|
166 |
self.assertEqual(publisher.channels[routing_key].published_count, 0) |
|
|
167 |
|
|
|
168 |
publisher.publish(metric_name, self.message) |
|
|
169 |
self.assertEqual(publisher.channels[routing_key].published_count, 1) |
|
|
170 |
|
|
|
171 |
expected_message = [{ |
|
|
172 |
"channel_5": 5, "channel_4": 4, "channel_7": 7, |
|
|
173 |
"channel_6": 6, "channel_1": 1, "channel_0": 0, |
|
|
174 |
"channel_3": 3, "channel_2": 2, "timestamp": 100 |
|
|
175 |
}, { |
|
|
176 |
"channel_5": 5, "channel_4": 4, "channel_7": 7, |
|
|
177 |
"channel_6": 6, "channel_1": 1, "channel_0": 0, |
|
|
178 |
"channel_3": 3, "channel_2": 2, "timestamp": 100 |
|
|
179 |
}] |
|
|
180 |
|
|
|
181 |
published_message = json.loads( |
|
|
182 |
publisher.channels[routing_key].message_published) |
|
|
183 |
self.assertEqual(published_message, expected_message) |
|
|
184 |
|
|
|
185 |
|
|
|
186 |
@patch('requests.get', mock_requests_get) |
|
|
187 |
@patch('serial.Serial', MockSerial) |
|
|
188 |
@patch('pika.BlockingConnection', MockBlockingConnection) |
|
|
189 |
@patch('cloudbrain.connectors.openbci.OpenBCIConnector.start', |
|
|
190 |
mock_start) |
|
|
191 |
def test_OpenBCISource(self): |
|
|
192 |
options = {"rabbitmq_user": self.rabbitmq_user, |
|
|
193 |
"rabbitmq_pwd": self.rabbitmq_pwd} |
|
|
194 |
|
|
|
195 |
publisher = PikaPublisher(self.base_routing_key, **options) |
|
|
196 |
publisher.connect() |
|
|
197 |
publisher.register(self.metric_name, self.num_channels, |
|
|
198 |
self.buffer_size) |
|
|
199 |
|
|
|
200 |
publishers = [publisher] |
|
|
201 |
subscribers = [] |
|
|
202 |
source = OpenBCISource(subscribers=subscribers, |
|
|
203 |
publishers=publishers, |
|
|
204 |
port=self.port, |
|
|
205 |
baud=self.baud, |
|
|
206 |
filter_data=self.filter_data) |
|
|
207 |
source.start() |
|
|
208 |
|
|
|
209 |
pub = source.publishers[0] |
|
|
210 |
|
|
|
211 |
routing_key = "%s:%s" % (self.base_routing_key, self.metric_name) |
|
|
212 |
self.assertEqual(pub.channels[routing_key].published_count, 0) |
|
|
213 |
self.assertEqual(pub.channels[routing_key].message_published, |
|
|
214 |
None, "No messages should have been sent yet.") |