Diff of /main.py [000000] .. [597177]

Switch to unified view

a b/main.py
1
import glob
2
import sys
3
4
from args import get_args_parser
5
from data import get_loaders, generate_data, split_data
6
from model import eegt
7
from engine import prepare_training, train_model
8
from torchsummary import summary
9
10
if __name__ == "__main__":
11
    parser = get_args_parser()
12
    args = parser.parse_args(args=[])
13
    sys.stdout = open('logs/exp_4000_drop_5e-6.txt', 'w')
14
    model, optimizer, lr_scheduler, criterion, device, _ = prepare_training(args)
15
    print(summary(model, (59, 4000)))
16
17
    calib_files = glob.glob('data/*.mat')
18
    X, y = generate_data(calib_files)
19
    train_X, train_y, val_X, val_y, test_X, test_y = split_data(X, y)
20
    dataloaders = get_loaders(train_X, train_y, val_X, val_y, test_X, test_y)
21
    
22
    best_model = train_model(model, criterion, optimizer, lr_scheduler, device, dataloaders)