Diff of /engine.py [000000] .. [597177]

b/engine.py
import sys, time, copy
import torch
import torch.nn as nn
import tqdm

from timm.models import create_model
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler

def prepare_training(args):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = create_model('eegt').to(device)
    optimizer = create_optimizer(args, model)
    lr_scheduler, _ = create_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()
    loss_scaler = NativeScaler()  # AMP gradient scaler; returned for callers that train with mixed precision
    print('Using device:', device)
    return model, optimizer, lr_scheduler, criterion, device, loss_scaler
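
# A hedged sketch, not in the original file: create_optimizer and create_scheduler
# read hyperparameters off an argparse-style namespace. The exact attribute set
# varies across timm versions, so treat the fields below as assumptions.
def example_args():
    from argparse import Namespace
    return Namespace(
        # read by timm.optim.create_optimizer
        opt='adamw', lr=1e-3, weight_decay=0.05, momentum=0.9,
        # read by timm.scheduler.create_scheduler
        sched='cosine', epochs=50, min_lr=1e-5, warmup_lr=1e-6,
        warmup_epochs=5, decay_epochs=30, decay_rate=0.1, cooldown_epochs=0,
    )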

def train_model(model, criterion, optimizer, scheduler, device, dataloaders,
                args={'dataset_sizes': {'train': 1000, 'val': 197, 'test': 200}}):
    since = time.time()
    num_epochs = 50  # named once so the epoch counter and the progress print stay in sync

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        sys.stdout.flush()
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        # Each epoch has a training, a validation, and a test phase
        for phase in ['train', 'val', 'test']:
            if phase == 'train':
                model.train()  # set model to training mode
            else:
                model.eval()   # set model to evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in tqdm.tqdm(dataloaders[phase]):
                # .float()/.long() keep the cast device-agnostic; the original
                # torch.cuda.FloatTensor cast would crash on a CPU-only machine
                inputs = inputs.float().to(device)
                labels = labels.long().to(device).squeeze(1)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
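
                # The NativeScaler returned by prepare_training is unused here.
                # A hedged sketch of wiring it in, assuming CUDA and timm's
                # NativeScaler call signature: run the forward pass under
                # torch.cuda.amp.autocast and let the scaler own backward + step:
                #
                #     with torch.cuda.amp.autocast():
                #         outputs = model(inputs)
                #         loss = criterion(outputs, labels)
                #     if phase == 'train':
                #         loss_scaler(loss, optimizer, parameters=model.parameters())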

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)
            if phase == 'train':
                scheduler.step(epoch=epoch)

            epoch_loss = running_loss / args['dataset_sizes'][phase]
            epoch_acc = running_corrects.double() / args['dataset_sizes'][phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep-copy the weights whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
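
# A minimal usage sketch, not part of the original file. Only prepare_training
# and train_model come from above; example_args is the hypothetical helper
# defined earlier, and the random tensors (including the EEG input shape) are
# stand-in assumptions to be replaced by the real datasets.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    args = example_args()
    model, optimizer, lr_scheduler, criterion, device, loss_scaler = prepare_training(args)

    def fake_loader(n):
        x = torch.randn(n, 62, 200)       # (trials, electrodes, samples) is a guess for the 'eegt' model
        y = torch.randint(0, 2, (n, 1))   # labels carry a trailing dim, matching the squeeze(1) above
        return DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

    dataloaders = {'train': fake_loader(1000), 'val': fake_loader(197), 'test': fake_loader(200)}
    model = train_model(model, criterion, optimizer, lr_scheduler, device, dataloaders)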