|
a |
|
b/test/unit_tests/test_util.py |
|
|
1 |
# Authors: Hubert Banville <hubert.jbanville@gmail.com> |
|
|
2 |
# Bruno Aristimunha <b.aristimunha@gmail.com> |
|
|
3 |
# License: BSD-3 |
|
|
4 |
|
|
|
5 |
import os |
|
|
6 |
from unittest import mock |
|
|
7 |
|
|
|
8 |
import mne |
|
|
9 |
import numpy as np |
|
|
10 |
import h5py |
|
|
11 |
import pytest |
|
|
12 |
import torch |
|
|
13 |
import tempfile |
|
|
14 |
|
|
|
15 |
from sklearn.utils import check_random_state |
|
|
16 |
from numpy.testing import assert_array_equal, assert_allclose |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
from braindecode.util import _cov_and_var_to_corr, _cov_to_corr, \ |
|
|
20 |
corr, create_mne_dummy_raw, \ |
|
|
21 |
read_all_file_names, set_random_seeds, th_to_np, cov, np_to_th, \ |
|
|
22 |
get_balanced_batches |
|
|
23 |
|
|
|
24 |
|
|
|
25 |
def test_create_mne_dummy_raw(tmp_path): |
|
|
26 |
n_channels, n_times, sfreq = 2, 10000, 100 |
|
|
27 |
raw, fnames = create_mne_dummy_raw( |
|
|
28 |
n_channels, n_times, sfreq, savedir=tmp_path, save_format=["fif", "hdf5"] |
|
|
29 |
) |
|
|
30 |
|
|
|
31 |
assert isinstance(raw, mne.io.RawArray) |
|
|
32 |
assert len(raw.ch_names) == n_channels |
|
|
33 |
assert raw.n_times == n_times |
|
|
34 |
assert raw.info["sfreq"] == sfreq |
|
|
35 |
assert isinstance(fnames, dict) |
|
|
36 |
assert os.path.isfile(fnames["fif"]) |
|
|
37 |
assert os.path.isfile(fnames["hdf5"]) |
|
|
38 |
|
|
|
39 |
_ = mne.io.read_raw_fif(fnames["fif"], preload=False, verbose=None) |
|
|
40 |
with h5py.File(fnames["hdf5"], "r") as hf: |
|
|
41 |
_ = np.array(hf["fake_raw"]) |
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def test_set_random_seeds_raise_value_error(): |
|
|
45 |
with pytest.raises( |
|
|
46 |
ValueError, match="cudnn_benchmark expected to be bool or None, got 'abc'" |
|
|
47 |
): |
|
|
48 |
set_random_seeds(100, True, "abc") |
|
|
49 |
|
|
|
50 |
|
|
|
51 |
def test_set_random_seeds_warning(): |
|
|
52 |
torch.backends.cudnn.benchmark = True |
|
|
53 |
with pytest.warns( |
|
|
54 |
UserWarning, |
|
|
55 |
match="torch.backends.cudnn.benchmark was set to True which may results in " |
|
|
56 |
"lack of reproducibility. In some cases to ensure reproducibility you " |
|
|
57 |
"may need to set torch.backends.cudnn.benchmark to False.", |
|
|
58 |
): |
|
|
59 |
set_random_seeds(100, True) |
|
|
60 |
|
|
|
61 |
|
|
|
62 |
def test_set_random_seeds_with_valid_cudnn_benchmark(): |
|
|
63 |
with mock.patch("torch.backends.cudnn") as mock_cudnn: |
|
|
64 |
# Test with cudnn_benchmark = True |
|
|
65 |
set_random_seeds(42, cuda=True, cudnn_benchmark=True) |
|
|
66 |
assert mock_cudnn.benchmark is True |
|
|
67 |
|
|
|
68 |
# Test with cudnn_benchmark = False |
|
|
69 |
set_random_seeds(42, cuda=True, |
|
|
70 |
cudnn_benchmark=False) |
|
|
71 |
assert mock_cudnn.benchmark is False |
|
|
72 |
|
|
|
73 |
|
|
|
74 |
def test_set_random_seeds_with_invalid_cudnn_benchmark(): |
|
|
75 |
with pytest.raises(ValueError): |
|
|
76 |
set_random_seeds(42, cuda=True, |
|
|
77 |
cudnn_benchmark='invalid_type') |
|
|
78 |
|
|
|
79 |
|
|
|
80 |
def test_th_to_np_data_preservation(): |
|
|
81 |
# Test with different data types |
|
|
82 |
for dtype in [torch.float32, torch.int32]: |
|
|
83 |
tensor = torch.tensor([1, 2, 3], dtype=dtype) |
|
|
84 |
np_array = th_to_np(tensor) |
|
|
85 |
assert np_array.dtype == tensor.numpy().dtype |
|
|
86 |
# Corrected attribute access |
|
|
87 |
assert_array_equal(np_array, tensor.numpy()) |
|
|
88 |
|
|
|
89 |
|
|
|
90 |
def test_th_to_np_on_cpu(): |
|
|
91 |
# Create a tensor on CPU |
|
|
92 |
cpu_tensor = torch.tensor([1, 2, 3], dtype=torch.float32) |
|
|
93 |
np_array = th_to_np(cpu_tensor) |
|
|
94 |
# Check the type and data of the numpy array |
|
|
95 |
assert isinstance(np_array, np.ndarray) |
|
|
96 |
assert np_array.dtype == cpu_tensor.numpy().dtype |
|
|
97 |
# Correct way to check dtype |
|
|
98 |
assert_array_equal(np_array, cpu_tensor.numpy()) |
|
|
99 |
|
|
|
100 |
|
|
|
101 |
def test_cov_basic(): |
|
|
102 |
# Create two simple identical arrays |
|
|
103 |
a = np.array([[1, 2, 3], [4, 5, 6]]) |
|
|
104 |
b = np.array([[1, 2, 3], [4, 5, 6]]) |
|
|
105 |
expected_cov = np.array([[1, 1], [1, 1]]) # Calculated expected covariance |
|
|
106 |
computed_cov = cov(a, b) |
|
|
107 |
assert_allclose(computed_cov, expected_cov, rtol=1e-5) |
|
|
108 |
|
|
|
109 |
|
|
|
110 |
def test_cov_dimension_mismatch(): |
|
|
111 |
# Arrays with mismatched sample size should raise an error |
|
|
112 |
a = np.array([[1, 2], [3, 4]]) |
|
|
113 |
b = np.array([[1, 2, 3]]) |
|
|
114 |
with pytest.raises(ValueError): |
|
|
115 |
cov(a, b) |
|
|
116 |
|
|
|
117 |
|
|
|
118 |
def test_np_to_th_basic_conversion(): |
|
|
119 |
# Convert a simple list to tensor |
|
|
120 |
data = [1, 2, 3] |
|
|
121 |
tensor = np_to_th(data) |
|
|
122 |
assert torch.is_tensor(tensor) |
|
|
123 |
assert_array_equal(tensor.numpy(), np.array(data)) |
|
|
124 |
|
|
|
125 |
|
|
|
126 |
def test_np_to_th_dtype_conversion(): |
|
|
127 |
# Convert and specify dtype |
|
|
128 |
data = [1.0, 2.0, 3.0] |
|
|
129 |
tensor = np_to_th(data, dtype=np.float32) |
|
|
130 |
assert tensor.dtype == torch.float32 |
|
|
131 |
|
|
|
132 |
|
|
|
133 |
def test_np_to_th_requires_grad(): |
|
|
134 |
# Check requires_grad attribute |
|
|
135 |
data = np.array([1.0, 2.0, 3.0]) |
|
|
136 |
tensor = np_to_th(data, requires_grad=True) |
|
|
137 |
assert tensor.requires_grad is True |
|
|
138 |
|
|
|
139 |
|
|
|
140 |
@pytest.mark.skipif(not torch.cuda.is_available(), |
|
|
141 |
reason="Requires CUDA support") |
|
|
142 |
def test_np_to_th_pin_memory(): |
|
|
143 |
# Create a numpy array |
|
|
144 |
data = np.array([1, 2, 3]) |
|
|
145 |
|
|
|
146 |
# Convert the numpy array to a tensor with pin_memory=True |
|
|
147 |
tensor = np_to_th(data, pin_memory=True) |
|
|
148 |
# Check if the tensor is pinned in memory |
|
|
149 |
assert tensor.is_pinned() is True |
|
|
150 |
# Convert the numpy array to a tensor with pin_memory=False |
|
|
151 |
tensor = np_to_th(data, pin_memory=False) |
|
|
152 |
# Check if the tensor is not pinned in memory |
|
|
153 |
assert tensor.is_pinned() is False |
|
|
154 |
|
|
|
155 |
def test_np_to_th_requires_grad_unsupported_dtype(): |
|
|
156 |
# Attempt to set requires_grad on an unsupported dtype (integers) |
|
|
157 |
data = np.array([1, 2, 3]) |
|
|
158 |
with pytest.raises(RuntimeError, |
|
|
159 |
match="Only Tensors of floating point and complex dtype can require gradients"): |
|
|
160 |
np_to_th(data, requires_grad=True) |
|
|
161 |
|
|
|
162 |
|
|
|
163 |
def test_np_to_th_tensor_options(): |
|
|
164 |
# Additional tensor options like device |
|
|
165 |
data = [1, 2, 3] |
|
|
166 |
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
167 |
tensor = np_to_th(data, device=device) |
|
|
168 |
assert tensor.device.type == device |
|
|
169 |
|
|
|
170 |
|
|
|
171 |
def test_np_to_th_single_number_conversion(): |
|
|
172 |
# Single number conversion to tensor |
|
|
173 |
data = 42 |
|
|
174 |
tensor = np_to_th(data) |
|
|
175 |
assert torch.is_tensor(tensor) |
|
|
176 |
assert tensor.item() == 42 # Check if the value is correct |
|
|
177 |
|
|
|
178 |
|
|
|
179 |
def test_cov_and_var_to_corr_zero_variance(): |
|
|
180 |
# Scenario with zero variance, expecting result to handle divide by zero |
|
|
181 |
this_cov = np.array([[1, 0], |
|
|
182 |
[0, 1]]) |
|
|
183 |
var_a = np.array([0, 1]) # Zero variance in the first variable |
|
|
184 |
var_b = np.array([1, 0]) # Zero variance in the second variable |
|
|
185 |
calculated_corr = _cov_and_var_to_corr(this_cov, var_a, var_b) |
|
|
186 |
|
|
|
187 |
# Expected correlation matrix |
|
|
188 |
expected_corr = np.array([[np.inf, np.nan], |
|
|
189 |
[0, np.inf]]) |
|
|
190 |
assert_array_equal(calculated_corr, expected_corr) |
|
|
191 |
|
|
|
192 |
|
|
|
193 |
def test_cov_and_var_to_corr_single_element(): |
|
|
194 |
# Testing with single-element arrays |
|
|
195 |
this_cov = np.array([[1]]) |
|
|
196 |
var_a = np.array([1]) |
|
|
197 |
var_b = np.array([1]) |
|
|
198 |
expected_corr = np.array([[1]]) |
|
|
199 |
calculated_corr = _cov_and_var_to_corr(this_cov, var_a, var_b) |
|
|
200 |
assert_array_equal(calculated_corr, expected_corr) |
|
|
201 |
|
|
|
202 |
|
|
|
203 |
|
|
|
204 |
def test_cov_to_corr_unbiased(): |
|
|
205 |
# Create datasets a and b with known covariance and variance characteristics |
|
|
206 |
a = np.array([[1, 2, 3, 4], |
|
|
207 |
[2, 3, 4, 5]]) |
|
|
208 |
b = np.array([[1, 3, 5, 7], |
|
|
209 |
[5, 6, 7, 8]]) |
|
|
210 |
# Covariance between the features of a and b |
|
|
211 |
# Calculating covariance manually for known values |
|
|
212 |
demeaned_a = a - np.mean(a, axis=1, keepdims=True) |
|
|
213 |
demeaned_b = b - np.mean(b, axis=1, keepdims=True) |
|
|
214 |
this_cov = np.dot(demeaned_a, demeaned_b.T) / (b.shape[1] - 1) |
|
|
215 |
|
|
|
216 |
# Compute expected correlation using standard formulas for correlation |
|
|
217 |
var_a = np.var(a, axis=1, ddof=1) |
|
|
218 |
var_b = np.var(b, axis=1, ddof=1) |
|
|
219 |
expected_divisor = np.outer(np.sqrt(var_a), np.sqrt(var_b)) |
|
|
220 |
expected_corr = this_cov / expected_divisor |
|
|
221 |
|
|
|
222 |
# Compute correlation using the function |
|
|
223 |
calculated_corr = _cov_to_corr(this_cov, a, b) |
|
|
224 |
|
|
|
225 |
# Assert that the calculated correlation matches the expected correlation |
|
|
226 |
assert_allclose(calculated_corr, expected_corr, rtol=1e-5) |
|
|
227 |
|
|
|
228 |
|
|
|
229 |
def test_balanced_batches_basic(): |
|
|
230 |
n_trials = 100 |
|
|
231 |
seed = 42 |
|
|
232 |
rng = check_random_state(seed) |
|
|
233 |
n_batches = 10 |
|
|
234 |
batches = get_balanced_batches(n_trials, rng, shuffle=False, |
|
|
235 |
n_batches=n_batches) |
|
|
236 |
|
|
|
237 |
# Check correct number of batches |
|
|
238 |
assert len(batches) == n_batches |
|
|
239 |
|
|
|
240 |
# Check balanced batch sizes |
|
|
241 |
all_batch_sizes = [len(batch) for batch in batches] |
|
|
242 |
max_size = max(all_batch_sizes) |
|
|
243 |
min_size = min(all_batch_sizes) |
|
|
244 |
assert max_size - min_size <= 1 |
|
|
245 |
|
|
|
246 |
# Check if all indices are unique and accounted for |
|
|
247 |
all_indices = np.concatenate(batches) |
|
|
248 |
assert np.array_equal(np.sort(all_indices), np.arange(n_trials)) |
|
|
249 |
|
|
|
250 |
|
|
|
251 |
def test_balanced_batches_with_batch_size(): |
|
|
252 |
n_trials = 105 |
|
|
253 |
seed = 42 |
|
|
254 |
rng = check_random_state(seed) |
|
|
255 |
batch_size = 20 |
|
|
256 |
batches = get_balanced_batches(n_trials, rng, shuffle=False, |
|
|
257 |
batch_size=batch_size) |
|
|
258 |
|
|
|
259 |
# Check the modified batch size condition |
|
|
260 |
expected_n_batches = int(np.round(n_trials / float(batch_size))) |
|
|
261 |
assert len(batches) == expected_n_batches |
|
|
262 |
|
|
|
263 |
# Checking the total number of indices |
|
|
264 |
all_indices = np.concatenate(batches) |
|
|
265 |
assert len(all_indices) == n_trials |
|
|
266 |
|
|
|
267 |
|
|
|
268 |
def test_balanced_batches_shuffle(): |
|
|
269 |
n_trials = 50 |
|
|
270 |
seed = 42 |
|
|
271 |
rng = check_random_state(seed) |
|
|
272 |
batches_no_shuffle = get_balanced_batches(n_trials, rng, shuffle=False, |
|
|
273 |
batch_size=10) |
|
|
274 |
rng = check_random_state(seed) |
|
|
275 |
batches_with_shuffle = get_balanced_batches(n_trials, rng, shuffle=True, |
|
|
276 |
batch_size=10) |
|
|
277 |
|
|
|
278 |
# Check that shuffling changes the order of indices |
|
|
279 |
assert not np.array_equal(np.concatenate(batches_no_shuffle), |
|
|
280 |
np.concatenate(batches_with_shuffle)) |
|
|
281 |
|
|
|
282 |
|
|
|
283 |
def test_corr_correlation_computation(): |
|
|
284 |
# Create two 2D arrays with known correlation |
|
|
285 |
a = np.array([[1, 2, 3], [4, 5, 6]]) |
|
|
286 |
b = np.array([[1, 2, 3], [4, 5, 6]]) |
|
|
287 |
|
|
|
288 |
# Call the corr function |
|
|
289 |
corr_result = corr(a, b) |
|
|
290 |
|
|
|
291 |
# Compute the known correlation |
|
|
292 |
known_corr = np.array([np.corrcoef(a[i], b[i]) for i in range(a.shape[0])]) |
|
|
293 |
|
|
|
294 |
# Extract the correlation computation from the corr function |
|
|
295 |
computed_corr = _cov_to_corr(cov(a, b), a, b) |
|
|
296 |
|
|
|
297 |
# Assert that the computed correlation matches the known correlation |
|
|
298 |
assert np.allclose(computed_corr, known_corr) |
|
|
299 |
assert np.allclose(corr_result, computed_corr) |
|
|
300 |
|
|
|
301 |
|
|
|
302 |
def test_get_balanced_batches_zero_batches(): |
|
|
303 |
# Create a scenario where n_batches is 0 |
|
|
304 |
n_trials = 10 |
|
|
305 |
rng = check_random_state(0) |
|
|
306 |
shuffle = False |
|
|
307 |
n_batches = 0 |
|
|
308 |
batch_size = None |
|
|
309 |
|
|
|
310 |
# Call the get_balanced_batches function |
|
|
311 |
batches = get_balanced_batches(n_trials, rng, shuffle, n_batches, batch_size) |
|
|
312 |
|
|
|
313 |
# Check if the function returns a single batch with all trials |
|
|
314 |
assert len(batches) == 1 |
|
|
315 |
assert len(batches[0]) == n_trials |
|
|
316 |
|
|
|
317 |
|
|
|
318 |
def test_get_balanced_batches_i_batch_less_than_n_batches_with_extra_trial(): |
|
|
319 |
# Create a scenario where i_batch < n_batches_with_extra_trial |
|
|
320 |
n_trials = 10 |
|
|
321 |
rng = check_random_state(0) |
|
|
322 |
shuffle = False |
|
|
323 |
n_batches = 6 |
|
|
324 |
batch_size = None |
|
|
325 |
|
|
|
326 |
# Call the get_balanced_batches function |
|
|
327 |
batches = get_balanced_batches(n_trials, rng, shuffle, n_batches, batch_size) |
|
|
328 |
|
|
|
329 |
# Check if the first batch has one more trial than the last batch |
|
|
330 |
assert len(batches[0]) > len(batches[-1]) |
|
|
331 |
|
|
|
332 |
|
|
|
333 |
def test_read_all_file_names(): |
|
|
334 |
# Create a temporary directory |
|
|
335 |
with tempfile.TemporaryDirectory() as tmpdir: |
|
|
336 |
temp_files = [] |
|
|
337 |
try: |
|
|
338 |
# Create some temporary files with .txt extension |
|
|
339 |
for i in range(5): |
|
|
340 |
temp_file = os.path.join(tmpdir, f'temp{i}.txt') |
|
|
341 |
with open(temp_file, 'w') as f: |
|
|
342 |
f.write('This is a temporary file.') |
|
|
343 |
temp_files.append(temp_file) |
|
|
344 |
|
|
|
345 |
# Call the read_all_file_names function |
|
|
346 |
file_paths = read_all_file_names(tmpdir, '.txt') |
|
|
347 |
|
|
|
348 |
# Check if the function found all the temporary files |
|
|
349 |
assert len(file_paths) == 5 |
|
|
350 |
|
|
|
351 |
# Check if the paths returned by the function are correct |
|
|
352 |
for i in range(5): |
|
|
353 |
assert os.path.join(tmpdir, f'temp{i}.txt') in file_paths |
|
|
354 |
finally: |
|
|
355 |
# Delete the temporary files |
|
|
356 |
for temp_file in temp_files: |
|
|
357 |
os.remove(temp_file) |
|
|
358 |
|
|
|
359 |
|
|
|
360 |
def test_read_all_file_names_error(): |
|
|
361 |
with pytest.raises(AssertionError): |
|
|
362 |
# Call the read_all_file_names function with a non-existent directory |
|
|
363 |
read_all_file_names('non_existent_dir', '.txt') |
|
|
364 |
|
|
|
365 |
|
|
|
366 |
def test_read_all_files_not_extension(): |
|
|
367 |
with pytest.raises(AssertionError): |
|
|
368 |
# Call the read_all_file_names function with a non-existent directory |
|
|
369 |
read_all_file_names('non_existent_dir', 'txt') |