import time
import copy
import torch
from torchnet import meter
from torch.autograd import Variable
from utils import plot_training
# Phase names, in the order train_model runs them every epoch; also used as
# the keys of the per-phase dicts (costs, accs, confusion meters).
data_cat = ['train', 'valid']
def train_model(model, criterion, optimizer, dataloaders, scheduler,
                dataset_sizes, num_epochs):
    """Train ``model`` and return it with its best validation weights loaded.

    Every epoch runs each phase listed in ``data_cat`` ('train', then
    'valid'). Within a phase the loss and the number of correct predictions
    are accumulated one iteration at a time, then printed together with a
    normalized 2-class confusion meter and stored for ``plot_training``.
    After the 'valid' phase the scheduler is stepped with the validation
    loss, and the model weights are snapshotted whenever the validation
    accuracy improves on the best value seen so far.

    Args:
        model: network called as ``model(inputs)``; its outputs are averaged
            into a single score per iteration and thresholded at 0.5.
        criterion: loss callable invoked as
            ``criterion(outputs, labels, phase)``.
        optimizer: optimizer; gradients are zeroed every iteration and a
            step is taken only in the 'train' phase.
        dataloaders: mapping phase -> iterable of batches. Each batch is
            indexed with ``'images'`` (element 0 is fed to the model) and
            ``'label'``.
        scheduler: stepped once per epoch as ``scheduler.step(valid_loss)``
            (presumably a ReduceLROnPlateau-style scheduler; verify against
            the caller).
        dataset_sizes: mapping phase -> count used as the denominator of
            the epoch loss and accuracy. Since one loss and one prediction
            are produced per iteration, this is assumed to be the number of
            iterations' worth of samples; confirm against the caller.
        num_epochs: number of epochs to run.

    Returns:
        ``model``, with the weights of the best-validation-accuracy epoch
        loaded (the initial weights if accuracy never exceeded 0.0).

    Note:
        Inputs and labels are moved with ``.cuda()``, so a CUDA device is
        required.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    costs = {x: [] for x in data_cat}  # loss per epoch, keyed by phase
    accs = {x: [] for x in data_cat}  # accuracy per epoch, keyed by phase
    print('Train batches:', len(dataloaders['train']))
    print('Valid batches:', len(dataloaders['valid']), '\n')
    for epoch in range(num_epochs):
        # Fresh meters every epoch so counts do not leak across epochs.
        confusion_matrix = {x: meter.ConfusionMeter(2, normalized=True)
                            for x in data_cat}
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in data_cat:
            # Training mode for 'train', inference mode otherwise.
            model.train(phase == 'train')
            running_loss = 0.0
            running_corrects = 0
            # NOTE(review): autograd is not disabled for the 'valid' phase;
            # on PyTorch >= 0.4 wrapping the forward pass in
            # torch.set_grad_enabled(phase == 'train') should save memory.
            for i, data in enumerate(dataloaders[phase]):
                print(i, end='\r')  # in-place progress indicator
                inputs = data['images'][0]
                labels = data['label'].type(torch.FloatTensor)
                # wrap them in Variable
                inputs = Variable(inputs.cuda())
                labels = Variable(labels.cuda())
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward: collapse all model outputs into one score
                outputs = model(inputs)
                outputs = torch.mean(outputs)
                loss = criterion(outputs, labels, phase)
                # Accumulate a plain Python float. ``loss.data[0]`` raises
                # on 0-dim tensors (PyTorch >= 0.5), so it is only used on
                # releases that predate ``.item()``.
                running_loss += (loss.item() if hasattr(loss, 'item')
                                 else loss.data[0])
                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                # statistics
                preds = (outputs.data > 0.5).type(torch.cuda.FloatTensor)
                # int() keeps the counter a Python int whether torch.sum
                # returns a number or a 0-dim tensor, so the accuracy below
                # is a true-division float rather than a tensor.
                running_corrects += int(torch.sum(preds == labels.data))
                confusion_matrix[phase].add(preds, labels.data)
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            costs[phase].append(epoch_loss)
            accs[phase].append(epoch_acc)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            print('Confusion Meter:\n', confusion_matrix[phase].value())
            if phase == 'valid':
                scheduler.step(epoch_loss)
                # deep copy the model when validation accuracy improves
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
        time_elapsed = time.time() - since
        print('Time elapsed: {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # '{:.4f}' (precision), not '{:4f}' (minimum field width).
    print('Best valid Acc: {:.4f}'.format(best_acc))
    plot_training(costs, accs)
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
def get_metrics(model, criterion, dataloaders, dataset_sizes, phase='valid'):
    """Print loss, accuracy and confusion meter of ``model`` over one phase.

    Iterates once over ``dataloaders[phase]`` without any optimizer step.
    The model is put in inference mode for the duration of the pass and its
    previous training/inference mode is restored afterwards.

    Args:
        model: network called as ``model(inputs)``; its outputs are averaged
            into a single score per iteration and thresholded at 0.5.
        criterion: loss callable invoked as
            ``criterion(outputs, labels, phase)``.
        dataloaders: mapping phase -> iterable of batches. Each batch is
            indexed with ``'images'`` (element 0 is fed to the model) and
            ``'label'``.
        dataset_sizes: mapping phase -> count used as the denominator of
            the reported loss and accuracy.
        phase: key selecting the data to evaluate; the docstring of the
            original lists 'train' or 'valid'. Also forwarded to
            ``criterion``.

    Returns:
        None. Results are printed.

    Note:
        Inputs and labels are moved with ``.cuda()``, so a CUDA device is
        required.
    """
    confusion_matrix = meter.ConfusionMeter(2, normalized=True)
    running_loss = 0.0
    running_corrects = 0
    # Evaluate in inference mode, as train_model does for its 'valid' phase,
    # so the metrics do not depend on the mode the caller left the model in.
    was_training = model.training
    model.train(False)
    try:
        for i, data in enumerate(dataloaders[phase]):
            print(i, end='\r')  # in-place progress indicator
            labels = data['label'].type(torch.FloatTensor)
            inputs = data['images'][0]
            # wrap them in Variable
            inputs = Variable(inputs.cuda())
            labels = Variable(labels.cuda())
            # forward: collapse all model outputs into one score
            outputs = model(inputs)
            outputs = torch.mean(outputs)
            loss = criterion(outputs, labels, phase)
            # statistics
            # One loss and one prediction are produced per iteration, so the
            # loss is accumulated unweighted (matching train_model); scaling
            # it by inputs.size(0) would not match the dataset_sizes
            # denominator used below. ``loss.data[0]`` raises on 0-dim
            # tensors (PyTorch >= 0.5), hence the ``.item()`` preference.
            running_loss += (loss.item() if hasattr(loss, 'item')
                             else loss.data[0])
            preds = (outputs.data > 0.5).type(torch.cuda.FloatTensor)
            # int() keeps the counter a Python int whether torch.sum returns
            # a number or a 0-dim tensor.
            running_corrects += int(torch.sum(preds == labels.data))
            confusion_matrix.add(preds, labels.data)
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    avg_loss = running_loss / dataset_sizes[phase]
    acc = running_corrects / dataset_sizes[phase]
    print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, avg_loss, acc))
    print('Confusion Meter:\n', confusion_matrix.value())