|
b/engine.py
|
import sys
import time
import copy

import torch
import torch.nn as nn
import tqdm

from timm.models import create_model
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler

|
def prepare_training(args):
    """Build the model, optimizer, LR scheduler, and loss for training."""
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = create_model('eegt').to(device)
    optimizer = create_optimizer(args, model)
    lr_scheduler, _ = create_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()
    # NativeScaler is returned for optional AMP training; note that
    # train_model below does not use it.
    loss_scaler = NativeScaler()
    print(f'Using device: {device}')
    return model, optimizer, lr_scheduler, criterion, device, loss_scaler

|
def train_model(model, criterion, optimizer, scheduler, device, dataloaders,
                args=None):
    # Avoid a mutable default argument; fall back to the dataset sizes used here.
    if args is None:
        args = {'dataset_sizes': {'train': 1000, 'val': 197, 'test': 200}}

    since = time.time()
    num_epochs = 50

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        sys.stdout.flush()
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        # Each epoch has a training, validation, and test phase
        for phase in ['train', 'val', 'test']:
            if phase == 'train':
                model.train()   # Set model to training mode
            else:
                model.eval()    # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in tqdm.tqdm(dataloaders[phase]):
                # .float()/.long() work on both CPU and GPU tensors, unlike
                # the torch.cuda.* tensor types, which fail on CPU-only hosts.
                inputs = inputs.float().to(device)
                labels = labels.long().to(device).squeeze(1)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # timm schedulers are stepped once per epoch with the epoch index
            if phase == 'train':
                scheduler.step(epoch=epoch)

            epoch_loss = running_loss / args['dataset_sizes'][phase]
            epoch_acc = running_corrects.double() / args['dataset_sizes'][phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model when validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
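

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original file. It assumes the
    # 'eegt' model is registered with timm by this repo, that the args fields
    # follow timm's create_optimizer/create_scheduler conventions ('opt',
    # 'lr', 'sched', 'epochs', ...), and that inputs are EEG-shaped batches
    # of (batch, channels, samples) with 4 classes; all values here are
    # illustrative assumptions, not values from this repo.
    from types import SimpleNamespace
    from torch.utils.data import DataLoader, TensorDataset

    args = SimpleNamespace(opt='adamw', lr=1e-3, weight_decay=0.05,
                           momentum=0.9, sched='cosine', epochs=50,
                           min_lr=1e-5, warmup_lr=1e-6, warmup_epochs=5,
                           cooldown_epochs=0, decay_rate=0.1)
    model, optimizer, lr_scheduler, criterion, device, loss_scaler = \
        prepare_training(args)

    def dummy_loader(n):
        # Hypothetical stand-in for real EEG dataloaders: 22 channels of
        # 1000 samples, labels of shape (n, 1) as train_model expects.
        x = torch.randn(n, 22, 1000)
        y = torch.randint(0, 4, (n, 1))
        return DataLoader(TensorDataset(x, y), batch_size=16)

    dataloaders = {phase: dummy_loader(n)
                   for phase, n in [('train', 64), ('val', 16), ('test', 16)]}
    model = train_model(model, criterion, optimizer, lr_scheduler, device,
                        dataloaders,
                        args={'dataset_sizes':
                              {'train': 64, 'val': 16, 'test': 16}})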