|
a |
|
b/main.py |
|
|
1 |
import glob |
|
|
2 |
import sys |
|
|
3 |
|
|
|
4 |
from args import get_args_parser |
|
|
5 |
from data import get_loaders, generate_data, split_data |
|
|
6 |
from model import eegt |
|
|
7 |
from engine import prepare_training, train_model |
|
|
8 |
from torchsummary import summary |
|
|
9 |
|
|
|
10 |
if __name__ == "__main__": |
|
|
11 |
parser = get_args_parser() |
|
|
12 |
args = parser.parse_args(args=[]) |
|
|
13 |
sys.stdout = open('logs/exp_4000_drop_5e-6.txt', 'w') |
|
|
14 |
model, optimizer, lr_scheduler, criterion, device, _ = prepare_training(args) |
|
|
15 |
print(summary(model, (59, 4000))) |
|
|
16 |
|
|
|
17 |
calib_files = glob.glob('data/*.mat') |
|
|
18 |
X, y = generate_data(calib_files) |
|
|
19 |
train_X, train_y, val_X, val_y, test_X, test_y = split_data(X, y) |
|
|
20 |
dataloaders = get_loaders(train_X, train_y, val_X, val_y, test_X, test_y) |
|
|
21 |
|
|
|
22 |
best_model = train_model(model, criterion, optimizer, lr_scheduler, device, dataloaders) |