test/unit_tests/test_util.py

# Authors: Hubert Banville <hubert.jbanville@gmail.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com>
# License: BSD-3

import os
import tempfile
from unittest import mock

import h5py
import mne
import numpy as np
import pytest
import torch
from numpy.testing import assert_allclose, assert_array_equal
from sklearn.utils import check_random_state

from braindecode.util import (
    _cov_and_var_to_corr,
    _cov_to_corr,
    corr,
    cov,
    create_mne_dummy_raw,
    get_balanced_batches,
    np_to_th,
    read_all_file_names,
    set_random_seeds,
    th_to_np,
)


def test_create_mne_dummy_raw(tmp_path):
    n_channels, n_times, sfreq = 2, 10000, 100
    raw, fnames = create_mne_dummy_raw(
        n_channels, n_times, sfreq, savedir=tmp_path, save_format=["fif", "hdf5"]
    )

    assert isinstance(raw, mne.io.RawArray)
    assert len(raw.ch_names) == n_channels
    assert raw.n_times == n_times
    assert raw.info["sfreq"] == sfreq
    assert isinstance(fnames, dict)
    assert os.path.isfile(fnames["fif"])
    assert os.path.isfile(fnames["hdf5"])

    _ = mne.io.read_raw_fif(fnames["fif"], preload=False, verbose=None)
    with h5py.File(fnames["hdf5"], "r") as hf:
        _ = np.array(hf["fake_raw"])


def test_set_random_seeds_raise_value_error():
    with pytest.raises(
        ValueError, match="cudnn_benchmark expected to be bool or None, got 'abc'"
    ):
        set_random_seeds(100, True, "abc")


def test_set_random_seeds_warning():
    torch.backends.cudnn.benchmark = True
    with pytest.warns(
        UserWarning,
        match="torch.backends.cudnn.benchmark was set to True which may results in "
        "lack of reproducibility. In some cases to ensure reproducibility you "
        "may need to set torch.backends.cudnn.benchmark to False.",
    ):
        set_random_seeds(100, True)


def test_set_random_seeds_with_valid_cudnn_benchmark():
    with mock.patch("torch.backends.cudnn") as mock_cudnn:
        # Test with cudnn_benchmark=True
        set_random_seeds(42, cuda=True, cudnn_benchmark=True)
        assert mock_cudnn.benchmark is True

        # Test with cudnn_benchmark=False
        set_random_seeds(42, cuda=True, cudnn_benchmark=False)
        assert mock_cudnn.benchmark is False


def test_set_random_seeds_with_invalid_cudnn_benchmark():
    with pytest.raises(ValueError):
        set_random_seeds(42, cuda=True, cudnn_benchmark='invalid_type')


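# Usage sketch (hedged): the tests above exercise the signature
# set_random_seeds(seed, cuda, cudnn_benchmark=None). Assuming the function
# seeds the Python, NumPy and PyTorch RNGs, as its name suggests, a typical
# call at the top of an experiment script would be:
#
#     set_random_seeds(seed=42, cuda=torch.cuda.is_available(),
#                      cudnn_benchmark=False)
#
# Passing cudnn_benchmark=False trades cuDNN's autotuning speed for
# reproducible kernel selection on GPU.
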
def test_th_to_np_data_preservation():
    # Check that values and dtype survive the conversion for several dtypes
    for dtype in [torch.float32, torch.int32]:
        tensor = torch.tensor([1, 2, 3], dtype=dtype)
        np_array = th_to_np(tensor)
        assert np_array.dtype == tensor.numpy().dtype
        assert_array_equal(np_array, tensor.numpy())


def test_th_to_np_on_cpu():
    # Create a tensor on the CPU
    cpu_tensor = torch.tensor([1, 2, 3], dtype=torch.float32)
    np_array = th_to_np(cpu_tensor)
    # Check the type, dtype and contents of the resulting numpy array
    assert isinstance(np_array, np.ndarray)
    assert np_array.dtype == cpu_tensor.numpy().dtype
    assert_array_equal(np_array, cpu_tensor.numpy())


def test_cov_basic():
    # Two identical arrays whose rows are the variables
    a = np.array([[1, 2, 3], [4, 5, 6]])
    b = np.array([[1, 2, 3], [4, 5, 6]])
    expected_cov = np.array([[1, 1], [1, 1]])  # Expected unbiased covariance
    computed_cov = cov(a, b)
    assert_allclose(computed_cov, expected_cov, rtol=1e-5)


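# Worked check of expected_cov above: each row of a and b demeans to
# [-1, 0, 1], so every pairwise row covariance is
# ((-1)*(-1) + 0*0 + 1*1) / (3 - 1) = 1, assuming cov uses the unbiased
# (n - 1) divisor, as test_cov_to_corr_unbiased below computes manually.
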
def test_cov_dimension_mismatch():
    # Arrays with mismatched numbers of observations should raise an error
    a = np.array([[1, 2], [3, 4]])
    b = np.array([[1, 2, 3]])
    with pytest.raises(ValueError):
        cov(a, b)


def test_np_to_th_basic_conversion():
    # Convert a simple list to tensor
    data = [1, 2, 3]
    tensor = np_to_th(data)
    assert torch.is_tensor(tensor)
    assert_array_equal(tensor.numpy(), np.array(data))


def test_np_to_th_dtype_conversion():
    # Convert and specify dtype
    data = [1.0, 2.0, 3.0]
    tensor = np_to_th(data, dtype=np.float32)
    assert tensor.dtype == torch.float32


def test_np_to_th_requires_grad():
    # Check requires_grad attribute
    data = np.array([1.0, 2.0, 3.0])
    tensor = np_to_th(data, requires_grad=True)
    assert tensor.requires_grad is True


@pytest.mark.skipif(not torch.cuda.is_available(),
                    reason="Requires CUDA support")
def test_np_to_th_pin_memory():
    # Create a numpy array
    data = np.array([1, 2, 3])

    # Convert the numpy array to a tensor with pin_memory=True
    tensor = np_to_th(data, pin_memory=True)
    # Check if the tensor is pinned in memory
    assert tensor.is_pinned() is True
    # Convert the numpy array to a tensor with pin_memory=False
    tensor = np_to_th(data, pin_memory=False)
    # Check if the tensor is not pinned in memory
    assert tensor.is_pinned() is False


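# Note on pin_memory: pinned (page-locked) host memory enables faster,
# optionally asynchronous host-to-device copies, e.g.
# tensor.to('cuda', non_blocking=True); that is why the flag is only
# meaningful when CUDA is available, matching the skipif guard above.
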
def test_np_to_th_requires_grad_unsupported_dtype():
    # Attempt to set requires_grad on an unsupported dtype (integers)
    data = np.array([1, 2, 3])
    with pytest.raises(RuntimeError,
                       match="Only Tensors of floating point and complex dtype "
                             "can require gradients"):
        np_to_th(data, requires_grad=True)


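# A minimal workaround sketch for the failure above: cast to a floating
# dtype during conversion, reusing the dtype parameter exercised in
# test_np_to_th_dtype_conversion:
#
#     np_to_th(data, dtype=np.float64, requires_grad=True)
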
def test_np_to_th_tensor_options():
    # Additional tensor options like device
    data = [1, 2, 3]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tensor = np_to_th(data, device=device)
    assert tensor.device.type == device


def test_np_to_th_single_number_conversion():
    # Single number conversion to tensor
    data = 42
    tensor = np_to_th(data)
    assert torch.is_tensor(tensor)
    assert tensor.item() == 42  # Check if the value is correct


def test_cov_and_var_to_corr_zero_variance():
    # Scenario with zero variance, expecting the divide-by-zero to
    # propagate as inf/nan
    this_cov = np.array([[1, 0],
                         [0, 1]])
    var_a = np.array([0, 1])  # Zero variance in the first variable
    var_b = np.array([1, 0])  # Zero variance in the second variable
    calculated_corr = _cov_and_var_to_corr(this_cov, var_a, var_b)

    # Expected correlation matrix
    expected_corr = np.array([[np.inf, np.nan],
                              [0, np.inf]])
    assert_array_equal(calculated_corr, expected_corr)
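# Worked division behind expected_corr above: the elementwise divisor is
# np.outer(np.sqrt(var_a), np.sqrt(var_b)) = [[0, 0], [1, 0]] (the formula
# test_cov_to_corr_unbiased uses below), so the entries of this_cov map to
# 1/0 = inf, 0/0 = nan, 0/1 = 0 and 1/0 = inf.
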


def test_cov_and_var_to_corr_single_element():
    # Testing with single-element arrays
    this_cov = np.array([[1]])
    var_a = np.array([1])
    var_b = np.array([1])
    expected_corr = np.array([[1]])
    calculated_corr = _cov_and_var_to_corr(this_cov, var_a, var_b)
    assert_array_equal(calculated_corr, expected_corr)


def test_cov_to_corr_unbiased():
    # Create datasets a and b with known covariance and variance characteristics
    a = np.array([[1, 2, 3, 4],
                  [2, 3, 4, 5]])
    b = np.array([[1, 3, 5, 7],
                  [5, 6, 7, 8]])
    # Compute the covariance between the rows of a and b manually,
    # using the unbiased (n - 1) divisor
    demeaned_a = a - np.mean(a, axis=1, keepdims=True)
    demeaned_b = b - np.mean(b, axis=1, keepdims=True)
    this_cov = np.dot(demeaned_a, demeaned_b.T) / (b.shape[1] - 1)

    # Compute the expected correlation using the standard formula
    var_a = np.var(a, axis=1, ddof=1)
    var_b = np.var(b, axis=1, ddof=1)
    expected_divisor = np.outer(np.sqrt(var_a), np.sqrt(var_b))
    expected_corr = this_cov / expected_divisor

    # Compute the correlation using the function under test
    calculated_corr = _cov_to_corr(this_cov, a, b)

    # Assert that the calculated correlation matches the expected correlation
    assert_allclose(calculated_corr, expected_corr, rtol=1e-5)
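# Sanity check on the fixture above: every row of a and b is a linear ramp,
# so all demeaned rows are positive multiples of one another and
# expected_corr works out to a matrix of ones (up to floating point error).
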


def test_balanced_batches_basic():
    n_trials = 100
    seed = 42
    rng = check_random_state(seed)
    n_batches = 10
    batches = get_balanced_batches(n_trials, rng, shuffle=False,
                                   n_batches=n_batches)

    # Check correct number of batches
    assert len(batches) == n_batches

    # Check balanced batch sizes
    all_batch_sizes = [len(batch) for batch in batches]
    max_size = max(all_batch_sizes)
    min_size = min(all_batch_sizes)
    assert max_size - min_size <= 1

    # Check if all indices are unique and accounted for
    all_indices = np.concatenate(batches)
    assert np.array_equal(np.sort(all_indices), np.arange(n_trials))
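# With these numbers the split is exact: 100 trials over 10 batches gives
# ten batches of 10 indices each, so max_size - min_size is 0.
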


def test_balanced_batches_with_batch_size():
    n_trials = 105
    seed = 42
    rng = check_random_state(seed)
    batch_size = 20
    batches = get_balanced_batches(n_trials, rng, shuffle=False,
                                   batch_size=batch_size)

    # The number of batches is n_trials / batch_size, rounded to the
    # nearest integer
    expected_n_batches = int(np.round(n_trials / float(batch_size)))
    assert len(batches) == expected_n_batches

    # Check the total number of indices
    all_indices = np.concatenate(batches)
    assert len(all_indices) == n_trials
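# Worked numbers: round(105 / 20) = round(5.25) = 5 batches, and since
# 105 = 5 * 21, each batch receives exactly 21 of the 105 indices.
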


def test_balanced_batches_shuffle():
    n_trials = 50
    seed = 42
    rng = check_random_state(seed)
    batches_no_shuffle = get_balanced_batches(n_trials, rng, shuffle=False,
                                              batch_size=10)
    rng = check_random_state(seed)
    batches_with_shuffle = get_balanced_batches(n_trials, rng, shuffle=True,
                                                batch_size=10)

    # Check that shuffling changes the order of indices
    assert not np.array_equal(np.concatenate(batches_no_shuffle),
                              np.concatenate(batches_with_shuffle))
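# Re-creating rng with the same seed before the second call ensures both
# calls start from identical RNG state, so any difference in the
# concatenated indices can only come from shuffle=True.
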


def test_corr_correlation_computation():
    # Create two 2D arrays with known correlation
    a = np.array([[1, 2, 3], [4, 5, 6]])
    b = np.array([[1, 2, 3], [4, 5, 6]])

    # Call the corr function
    corr_result = corr(a, b)

    # Compute the known correlation with numpy
    known_corr = np.array([np.corrcoef(a[i], b[i]) for i in range(a.shape[0])])

    # Recompute the correlation from its building blocks
    computed_corr = _cov_to_corr(cov(a, b), a, b)

    # Assert that the computed correlation matches the known correlation
    assert np.allclose(computed_corr, known_corr)
    assert np.allclose(corr_result, computed_corr)
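# Note on shapes: known_corr stacks one 2x2 np.corrcoef matrix per row and
# has shape (2, 2, 2), while computed_corr is (2, 2); np.allclose still
# passes because, with these identical linear rows, every entry of every
# matrix involved is exactly 1 and the comparison broadcasts.
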


def test_get_balanced_batches_zero_batches():
    # Create a scenario where n_batches is 0
    n_trials = 10
    rng = check_random_state(0)
    shuffle = False
    n_batches = 0
    batch_size = None

    # Call the get_balanced_batches function
    batches = get_balanced_batches(n_trials, rng, shuffle, n_batches, batch_size)

    # Check if the function returns a single batch with all trials
    assert len(batches) == 1
    assert len(batches[0]) == n_trials


def test_get_balanced_batches_i_batch_less_than_n_batches_with_extra_trial():
    # Create a scenario where i_batch < n_batches_with_extra_trial
    n_trials = 10
    rng = check_random_state(0)
    shuffle = False
    n_batches = 6
    batch_size = None

    # Call the get_balanced_batches function
    batches = get_balanced_batches(n_trials, rng, shuffle, n_batches, batch_size)

    # Check if the first batch has one more trial than the last batch
    assert len(batches[0]) > len(batches[-1])
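# Worked numbers: 10 trials over 6 batches gives 10 = 6 * 1 + 4, i.e. four
# batches of 2 indices and two batches of 1; the assertion implies that the
# extra trials go to the earliest batches, so batches[0] has 2 indices and
# batches[-1] has 1.
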


def test_read_all_file_names():
    # Create some temporary files with a .txt extension in a temporary
    # directory; the directory and its contents are removed automatically
    # when the context manager exits
    with tempfile.TemporaryDirectory() as tmpdir:
        for i in range(5):
            temp_file = os.path.join(tmpdir, f'temp{i}.txt')
            with open(temp_file, 'w') as f:
                f.write('This is a temporary file.')

        # Call the read_all_file_names function
        file_paths = read_all_file_names(tmpdir, '.txt')

        # Check if the function found all the temporary files
        assert len(file_paths) == 5

        # Check if the paths returned by the function are correct
        for i in range(5):
            assert os.path.join(tmpdir, f'temp{i}.txt') in file_paths
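# Usage sketch (hedged, based only on what these tests exercise):
# read_all_file_names(directory, extension) returns the full paths of the
# files in directory that match extension, and asserts both that the
# directory exists and that the extension starts with '.', as the two
# tests below check.
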
def test_read_all_file_names_error():
    with pytest.raises(AssertionError):
        # Call read_all_file_names with a non-existent directory
        read_all_file_names('non_existent_dir', '.txt')


def test_read_all_files_not_extension():
    with pytest.raises(AssertionError):
        # Call read_all_file_names with an extension missing its leading dot
        read_all_file_names('non_existent_dir', 'txt')