--- a +++ b/tests/unit/fft_test.py @@ -0,0 +1,203 @@ +import logging +import numpy as np +import unittest +import random + +from cloudbrain.modules.transforms.fft import FrequencyBandTransformer + +_LOGGER = logging.getLogger(__name__) +_LOGGER.level = logging.DEBUG +_LOGGER.addHandler(logging.StreamHandler()) + + +def generate_sine_wave(number_points, + sampling_frequency, + alpha_amplitude, + alpha_freq, + beta_amplitude, + beta_freq): + sample_spacing = 1.0 / sampling_frequency + x = np.linspace(start=0.0, + stop=number_points * sample_spacing, + num=number_points) + + alpha = alpha_amplitude * np.sin(alpha_freq * 2.0 * np.pi * x) + beta = beta_amplitude * np.sin(beta_freq * 2.0 * np.pi * x) + + y = alpha + beta + min(alpha_amplitude, + beta_amplitude) / 2.0 * random.random() + + return y + + + +def generate_mock_data(num_channels, + number_points, + sampling_frequency, + buffer_size, + alpha_amplitude, + alpha_freq, + beta_amplitude, + beta_freq): + sine_wave = generate_sine_wave(number_points, + sampling_frequency, + alpha_amplitude, + alpha_freq, + beta_amplitude, + beta_freq) + + buffers = [] + + # if number_points = 250 and buffer_size = 40, then num_buffers = 6 + 1 + num_buffers = int(number_points / buffer_size) + if number_points % buffer_size != 0: + num_buffers += 1 # you need to account for the almost full buffer + + t0 = 0 + sample_spacing = 1.0 / sampling_frequency + for i in range(num_buffers): + if (i + 1) * buffer_size < len(sine_wave): + y_chunk = sine_wave[i * buffer_size:(i + 1) * buffer_size] + else: + y_chunk = sine_wave[i * buffer_size:] + + buffer = [] + for j in range(len(y_chunk)): + timestamp_in_s = t0 + (buffer_size * i + j) * sample_spacing + timestamp = int(timestamp_in_s * 1000) + + datapoint = {'timestamp': timestamp} + + for k in range(num_channels): + datapoint['channel_%s' % k] = y_chunk[j] + + buffer.append(datapoint) + + buffers.append(buffer) + + return buffers + + + +def plot_cb_buffers(num_channels, cb_buffers): + maxi_buffer = [] + for cb_buffer in cb_buffers: + maxi_buffer.extend(cb_buffer) + plot_cb_buffer(num_channels, maxi_buffer) + + + +def plot_cb_buffer(num_channels, cb_buffer): + import matplotlib.pyplot as plt + f, axarr = plt.subplots(num_channels) + for i in range(num_channels): + channel_name = 'channel_%s' % i + data_to_plot = [] + for data in cb_buffer: + data_to_plot.append(data[channel_name]) + axarr[i].plot(data_to_plot) + axarr[i].set_title(channel_name) + plt.show() + + + +def generate_frequency_bands(alpha_freq, beta_freq, frequency_band_size): + """ + Make sure to choose the frequency band size carefully. + If it is too high, frequency bands will overlap. + If it's too low, the band might not be large enough to detect the frequency. + """ + + alpha_range = [alpha_freq - frequency_band_size / 2.0, + alpha_freq + frequency_band_size / 2.0] + beta_range = [beta_freq - frequency_band_size / 2.0, + beta_freq + frequency_band_size / 2.0] + + frequency_bands = {'alpha': alpha_range, 'beta': beta_range} + + return frequency_bands + + + +class FrequencyBandTranformerTest(unittest.TestCase): + def setUp(self): + + self.plot_input_data = False + + self.window_size = 250 # Also OK => 2 * 2.50. Or 3 * 250 + self.sampling_freq = 250.0 + + self.buffer_size = 10 + + self.alpha_amplitude = 10.0 + self.alpha_freq = 10.0 + + self.beta_amplitude = 5.0 + self.beta_freq = 25.0 + + self.num_channels = 8 + + self.freq_band_size = 10.0 + + self.number_points = 250 + + self.cb_buffers = generate_mock_data(self.num_channels, + self.number_points, + self.sampling_freq, + self.buffer_size, + self.alpha_amplitude, + self.alpha_freq, + self.beta_amplitude, + self.beta_freq) + + + def test_cb_buffers(self): + + if self.plot_input_data: + plot_cb_buffers(self.num_channels, self.cb_buffers) + + num_buffers = self.number_points / self.buffer_size + if self.window_size % self.buffer_size != 0: + num_buffers += 1 + + self.assertEqual(len(self.cb_buffers), num_buffers) + + self.assertEqual(len(self.cb_buffers[0]), self.buffer_size) + self.assertTrue('channel_0' in self.cb_buffers[0][0].keys()) + + + def test_module(self): + + self.frequency_bands = generate_frequency_bands(self.alpha_freq, + self.beta_freq, + self.freq_band_size) + + self.subscribers = [] + self.publishers = [] + + module = FrequencyBandTransformer(subscribers=self.subscribers, + publishers=self.publishers, + window_size=self.window_size, + sampling_frequency=self.sampling_freq, + frequency_bands=self.frequency_bands) + + for cb_buffer in self.cb_buffers: + # where the real logic inside the subscriber takes place + + bands = module._compute_fft(cb_buffer, self.num_channels) + + if bands: + for i in range(self.num_channels): + channel_name = 'channel_%s' % i + alpha_estimated_ampl = bands['alpha'][channel_name] + beta_estimated_ampl = bands['beta'][channel_name] + + ratio = self.beta_amplitude / self.alpha_amplitude + estimated_ratio = beta_estimated_ampl / alpha_estimated_ampl + + _LOGGER.debug("Alpha: estimated=%s | actual=%s" % + (alpha_estimated_ampl, self.alpha_amplitude)) + + _LOGGER.debug("Beta: estimated=%s | actual=%s" % + (beta_estimated_ampl, self.beta_amplitude)) + assert np.abs((estimated_ratio - ratio) / ratio) < 0.01