Diff of /tests/unit/fft_test.py [000000] .. [00c700]

Switch to side-by-side view

--- a
+++ b/tests/unit/fft_test.py
@@ -0,0 +1,203 @@
+import logging
+import numpy as np
+import unittest
+import random
+
+from cloudbrain.modules.transforms.fft import FrequencyBandTransformer
+
+_LOGGER = logging.getLogger(__name__)
+_LOGGER.level = logging.DEBUG
+_LOGGER.addHandler(logging.StreamHandler())
+
+
+def generate_sine_wave(number_points,
+                       sampling_frequency,
+                       alpha_amplitude,
+                       alpha_freq,
+                       beta_amplitude,
+                       beta_freq):
+    sample_spacing = 1.0 / sampling_frequency
+    x = np.linspace(start=0.0,
+                    stop=number_points * sample_spacing,
+                    num=number_points)
+
+    alpha = alpha_amplitude * np.sin(alpha_freq * 2.0 * np.pi * x)
+    beta = beta_amplitude * np.sin(beta_freq * 2.0 * np.pi * x)
+
+    y = alpha + beta + min(alpha_amplitude,
+                           beta_amplitude) / 2.0 * random.random()
+
+    return y
+
+
+
+def generate_mock_data(num_channels,
+                       number_points,
+                       sampling_frequency,
+                       buffer_size,
+                       alpha_amplitude,
+                       alpha_freq,
+                       beta_amplitude,
+                       beta_freq):
+    sine_wave = generate_sine_wave(number_points,
+                                   sampling_frequency,
+                                   alpha_amplitude,
+                                   alpha_freq,
+                                   beta_amplitude,
+                                   beta_freq)
+
+    buffers = []
+
+    # if number_points = 250 and buffer_size = 40, then num_buffers = 6 + 1
+    num_buffers = int(number_points / buffer_size)
+    if number_points % buffer_size != 0:
+        num_buffers += 1  # you need to account for the almost full buffer
+
+    t0 = 0
+    sample_spacing = 1.0 / sampling_frequency
+    for i in range(num_buffers):
+        if (i + 1) * buffer_size < len(sine_wave):
+            y_chunk = sine_wave[i * buffer_size:(i + 1) * buffer_size]
+        else:
+            y_chunk = sine_wave[i * buffer_size:]
+
+        buffer = []
+        for j in range(len(y_chunk)):
+            timestamp_in_s = t0 + (buffer_size * i + j) * sample_spacing
+            timestamp = int(timestamp_in_s * 1000)
+
+            datapoint = {'timestamp': timestamp}
+
+            for k in range(num_channels):
+                datapoint['channel_%s' % k] = y_chunk[j]
+
+            buffer.append(datapoint)
+
+        buffers.append(buffer)
+
+    return buffers
+
+
+
+def plot_cb_buffers(num_channels, cb_buffers):
+    maxi_buffer = []
+    for cb_buffer in cb_buffers:
+        maxi_buffer.extend(cb_buffer)
+    plot_cb_buffer(num_channels, maxi_buffer)
+
+
+
+def plot_cb_buffer(num_channels, cb_buffer):
+    import matplotlib.pyplot as plt
+    f, axarr = plt.subplots(num_channels)
+    for i in range(num_channels):
+        channel_name = 'channel_%s' % i
+        data_to_plot = []
+        for data in cb_buffer:
+            data_to_plot.append(data[channel_name])
+        axarr[i].plot(data_to_plot)
+        axarr[i].set_title(channel_name)
+    plt.show()
+
+
+
+def generate_frequency_bands(alpha_freq, beta_freq, frequency_band_size):
+    """
+    Make sure to choose the frequency band size carefully.
+    If it is too high, frequency bands will overlap.
+    If it's too low, the band might not be large enough to detect the frequency.
+    """
+
+    alpha_range = [alpha_freq - frequency_band_size / 2.0,
+                   alpha_freq + frequency_band_size / 2.0]
+    beta_range = [beta_freq - frequency_band_size / 2.0,
+                  beta_freq + frequency_band_size / 2.0]
+
+    frequency_bands = {'alpha': alpha_range, 'beta': beta_range}
+
+    return frequency_bands
+
+
+
+class FrequencyBandTranformerTest(unittest.TestCase):
+    def setUp(self):
+
+        self.plot_input_data = False
+
+        self.window_size = 250  # Also OK => 2 * 2.50. Or 3 * 250
+        self.sampling_freq = 250.0
+
+        self.buffer_size = 10
+
+        self.alpha_amplitude = 10.0
+        self.alpha_freq = 10.0
+
+        self.beta_amplitude = 5.0
+        self.beta_freq = 25.0
+
+        self.num_channels = 8
+
+        self.freq_band_size = 10.0
+
+        self.number_points = 250
+
+        self.cb_buffers = generate_mock_data(self.num_channels,
+                                             self.number_points,
+                                             self.sampling_freq,
+                                             self.buffer_size,
+                                             self.alpha_amplitude,
+                                             self.alpha_freq,
+                                             self.beta_amplitude,
+                                             self.beta_freq)
+
+
+    def test_cb_buffers(self):
+
+        if self.plot_input_data:
+            plot_cb_buffers(self.num_channels, self.cb_buffers)
+
+        num_buffers = self.number_points / self.buffer_size
+        if self.window_size % self.buffer_size != 0:
+            num_buffers += 1
+
+        self.assertEqual(len(self.cb_buffers), num_buffers)
+
+        self.assertEqual(len(self.cb_buffers[0]), self.buffer_size)
+        self.assertTrue('channel_0' in self.cb_buffers[0][0].keys())
+
+
+    def test_module(self):
+
+        self.frequency_bands = generate_frequency_bands(self.alpha_freq,
+                                                        self.beta_freq,
+                                                        self.freq_band_size)
+
+        self.subscribers = []
+        self.publishers = []
+
+        module = FrequencyBandTransformer(subscribers=self.subscribers,
+                                          publishers=self.publishers,
+                                          window_size=self.window_size,
+                                          sampling_frequency=self.sampling_freq,
+                                          frequency_bands=self.frequency_bands)
+
+        for cb_buffer in self.cb_buffers:
+            # where the real logic inside the subscriber takes place
+
+            bands = module._compute_fft(cb_buffer, self.num_channels)
+
+            if bands:
+                for i in range(self.num_channels):
+                    channel_name = 'channel_%s' % i
+                    alpha_estimated_ampl = bands['alpha'][channel_name]
+                    beta_estimated_ampl = bands['beta'][channel_name]
+
+                    ratio = self.beta_amplitude / self.alpha_amplitude
+                    estimated_ratio = beta_estimated_ampl / alpha_estimated_ampl
+
+                    _LOGGER.debug("Alpha: estimated=%s | actual=%s" %
+                                  (alpha_estimated_ampl, self.alpha_amplitude))
+
+                    _LOGGER.debug("Beta: estimated=%s | actual=%s" %
+                                  (beta_estimated_ampl, self.beta_amplitude))
+                    assert np.abs((estimated_ratio - ratio) / ratio) < 0.01