# Authors: Lukas Gemein <l.gemein@gmail.com>
#
#
# License: BSD-3
import torch
import numpy as np
from braindecode.models.tcn import TCN
from braindecode.util import set_random_seeds
def test_tcn():
    """Check TCN log-probabilities against a reference implementation.

    The expected numbers were produced by the original TCN implementation
    using the same constructor arguments as below. The model is seeded and
    switched to training mode before a single random input is passed
    through it. The log-softmax over the class axis is then compared to the
    reference within a loose absolute/relative tolerance.
    """
    # Reference log-probabilities from the original implementation; one
    # list of 70 values per output class.
    reference_class_0 = [
        -0.5504, -0.5304, -0.6023, -0.5231, -0.5387,
        -0.5522, -0.5323, -0.5540, -0.5297, -0.5333,
        -0.5743, -0.5330, -0.5117, -0.5051, -0.5523,
        -0.5507, -0.5724, -0.5380, -0.5697, -0.4871,
        -0.5400, -0.4986, -0.5502, -0.5524, -0.5263,
        -0.5440, -0.5464, -0.5005, -0.5404, -0.5098,
        -0.5197, -0.5578, -0.5419, -0.5601, -0.5031,
        -0.5616, -0.5205, -0.5378, -0.5472, -0.4897,
        -0.5216, -0.5560, -0.5480, -0.5488, -0.5258,
        -0.5637, -0.5318, -0.5134, -0.5460, -0.5294,
        -0.5513, -0.5310, -0.5307, -0.5326, -0.5270,
        -0.5156, -0.5569, -0.5416, -0.5279, -0.5553,
        -0.5589, -0.5166, -0.5108, -0.5076, -0.5279,
        -0.5208, -0.5367, -0.5557, -0.5690, -0.5494,
    ]
    reference_class_1 = [
        -0.8597, -0.8877, -0.7931, -0.8982, -0.8758,
        -0.8573, -0.8849, -0.8549, -0.8887, -0.8834,
        -0.8280, -0.8839, -0.9150, -0.9250, -0.8572,
        -0.8593, -0.8305, -0.8769, -0.8340, -0.9530,
        -0.8741, -0.9350, -0.8600, -0.8570, -0.8935,
        -0.8685, -0.8652, -0.9319, -0.8735, -0.9179,
        -0.9031, -0.8497, -0.8714, -0.8466, -0.9280,
        -0.8447, -0.9019, -0.8771, -0.8640, -0.9489,
        -0.9003, -0.8521, -0.8630, -0.8619, -0.8942,
        -0.8419, -0.8856, -0.9124, -0.8658, -0.8890,
        -0.8585, -0.8867, -0.8872, -0.8845, -0.8924,
        -0.9092, -0.8509, -0.8718, -0.8912, -0.8531,
        -0.8482, -0.9077, -0.9163, -0.9212, -0.8912,
        -0.9016, -0.8787, -0.8525, -0.8349, -0.8611,
    ]
    # Nest as (1, 2, 70); assert_allclose below requires the model's
    # log-probabilities to have this same shape.
    reference = np.array([[reference_class_0, reference_class_1]])

    # Seeding must happen before the model is built and the input is drawn,
    # so weight init, dropout masks and input are all reproducible.
    set_random_seeds(0, False)
    model = TCN(
        n_chans=21,
        n_outputs=2,
        n_filters=55,
        n_blocks=5,
        kernel_size=16,
        drop_prob=0.05270154233150525,
    )
    # Per the original test's note, braindecode models start in eval mode
    # while the reference implementation did not; switch to training mode
    # (dropout active) so both runs are comparable.
    model.train()

    inputs = torch.rand(1, 21, 1000, 1)
    logits = model(inputs)
    log_probs = torch.nn.functional.log_softmax(logits, dim=1)

    np.testing.assert_allclose(
        log_probs.detach().numpy(),
        reference,
        rtol=5e-2,
        atol=5e-2,
    )