Diff of /engine.py [000000] .. [597177]

--- /dev/null
+++ b/engine.py
@@ -0,0 +1,121 @@
+import sys
+import time
+import copy
+import torch
+import torch.nn as nn
+import tqdm
+
+from timm.models import create_model
+from timm.scheduler import create_scheduler
+from timm.optim import create_optimizer
+from timm.utils import NativeScaler
+
+def prepare_training(args):
+    """Build the model, optimizer, LR scheduler, loss, and AMP loss scaler."""
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    # 'eegt' must be registered with timm (via @register_model) before this call.
+    model = create_model('eegt').to(device)
+    optimizer = create_optimizer(args, model)
+    lr_scheduler, _ = create_scheduler(args, optimizer)
+    criterion = nn.CrossEntropyLoss()
+    loss_scaler = NativeScaler()
+    print('Using device:', device)
+    return model, optimizer, lr_scheduler, criterion, device, loss_scaler
+
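+# `dataloaders` is assumed to be a dict mapping each phase name to a
+# torch.utils.data.DataLoader, e.g. (sketch, with a hypothetical EEGDataset):
+#
+#   dataloaders = {phase: torch.utils.data.DataLoader(
+#                      EEGDataset(phase), batch_size=64, shuffle=(phase == 'train'))
+#                  for phase in ('train', 'val', 'test')}
+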
+def train_model(model, criterion, optimizer, scheduler, device, dataloaders,
+                dataset_sizes=None, num_epochs=50):
+    # Avoid a mutable default argument; fall back to the dataset's known sizes.
+    if dataset_sizes is None:
+        dataset_sizes = {'train': 1000, 'val': 197, 'test': 200}
+    since = time.time()
+
+    best_model_wts = copy.deepcopy(model.state_dict())
+    best_acc = 0.0
+    for epoch in range(num_epochs):
+        sys.stdout.flush()
+        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
+        print('-' * 10)
+
+        # Each epoch runs a training, validation, and test phase
+        for phase in ['train', 'val', 'test']:
+            if phase == 'train':
+                model.train()  # Set model to training mode
+            else:
+                model.eval()   # Set model to evaluate mode
+
+            running_loss = 0.0
+            running_corrects = 0
+
+            # Iterate over data.
+            for inputs, labels in tqdm.tqdm(dataloaders[phase]):
+                # Cast on the target device so CPU runs also work; labels
+                # arrive as (N, 1) and are squeezed to (N,) for the loss.
+                inputs = inputs.float().to(device)
+                labels = labels.long().to(device).squeeze(1)
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # forward
+                # track history if only in train
+                with torch.set_grad_enabled(phase == 'train'):
+                    outputs = model(inputs)
+                    _, preds = torch.max(outputs, 1)
+                    loss = criterion(outputs, labels)
+
+                    # backward + optimize only if in training phase
+                    if phase == 'train':
+                        loss.backward()
+                        optimizer.step()
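+                        # Note: the NativeScaler returned by prepare_training is
+                        # unused here; an AMP variant would be (sketch):
+                        #   with torch.cuda.amp.autocast():
+                        #       outputs = model(inputs)
+                        #       loss = criterion(outputs, labels)
+                        #   loss_scaler(loss, optimizer)  # scales, backwards, steps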
+
+                # statistics
+                running_loss += loss.item() * inputs.size(0)
+                running_corrects += torch.sum(preds == labels.data)
+            if phase == 'train':
+                scheduler.step(epoch=epoch)
+
+            epoch_loss = running_loss / dataset_sizes[phase]
+            epoch_acc = running_corrects.double() / dataset_sizes[phase]
+
+            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
+                phase, epoch_loss, epoch_acc))
+
+            # deep copy the model
+            if phase == 'val' and epoch_acc > best_acc:
+                best_acc = epoch_acc
+                best_model_wts = copy.deepcopy(model.state_dict())
+
+        print()
+
+    time_elapsed = time.time() - since
+    print('Training complete in {:.0f}m {:.0f}s'.format(
+        time_elapsed // 60, time_elapsed % 60))
+    print('Best val Acc: {:.4f}'.format(best_acc))
+
+    # load best model weights
+    model.load_state_dict(best_model_wts)
+    return model
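+
+# End-to-end usage (sketch): the Namespace fields below are illustrative
+# assumptions; only the options read by timm's create_optimizer and
+# create_scheduler matter, and `dataloaders` is the phase -> DataLoader dict
+# described above train_model.
+#
+#   from argparse import Namespace
+#   args = Namespace(opt='adamw', lr=1e-3, weight_decay=0.05, momentum=0.9,
+#                    sched='cosine', epochs=50, min_lr=1e-5, warmup_lr=1e-6,
+#                    warmup_epochs=5, decay_epochs=30, decay_rate=0.1,
+#                    cooldown_epochs=0)
+#   model, optimizer, lr_scheduler, criterion, device, loss_scaler = prepare_training(args)
+#   model = train_model(model, criterion, optimizer, lr_scheduler, device, dataloaders)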