import time
import copy

import torch
from torchnet import meter

from utils import plot_training

# Dataset phases iterated every epoch.
data_cat = ['train', 'valid']


def train_model(model, criterion, optimizer, dataloaders, scheduler,
                dataset_sizes, num_epochs):
    """Train ``model`` and return it loaded with the best validation weights.

    Args:
        model: network to optimize; inputs are moved to the model's device.
        criterion: loss callable invoked as ``criterion(outputs, labels, phase)``.
        optimizer: optimizer stepping ``model``'s parameters.
        dataloaders: dict mapping 'train'/'valid' to iterables yielding
            batches shaped ``{'images': ..., 'label': ...}``.
        scheduler: LR scheduler stepped with the validation loss each epoch.
        dataset_sizes: dict mapping phase name to its number of samples.
        num_epochs: number of passes over the training set.

    Returns:
        ``model`` with the weights that achieved the best validation accuracy.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    costs = {x: [] for x in data_cat}  # per-epoch loss, per phase
    accs = {x: [] for x in data_cat}   # per-epoch accuracy, per phase
    # Run on whichever device the model lives on (GPU if the caller moved it),
    # instead of hard-coding .cuda().
    device = next(model.parameters()).device
    print('Train batches:', len(dataloaders['train']))
    print('Valid batches:', len(dataloaders['valid']), '\n')
    for epoch in range(num_epochs):
        confusion_matrix = {x: meter.ConfusionMeter(2, normalized=True)
                            for x in data_cat}
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)
        # Each epoch has a training and a validation phase.
        for phase in data_cat:
            model.train(phase == 'train')
            running_loss = 0.0
            running_corrects = 0
            # Iterate over data.
            for i, data in enumerate(dataloaders[phase]):
                print(i, end='\r')
                # Batch layout assumed: data['images'][0] holds the input
                # views for one study — TODO confirm against the dataset.
                inputs = data['images'][0].to(device)
                labels = data['label'].type(torch.FloatTensor).to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # Build the autograd graph only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    # Study-level prediction: average the per-view outputs.
                    outputs = torch.mean(outputs)
                    loss = criterion(outputs, labels, phase)
                    # BUG FIX: ``loss.data[0]`` was removed in PyTorch >= 0.4;
                    # ``item()`` extracts the scalar loss as a Python float.
                    running_loss += loss.item()
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
                preds = (outputs.detach() > 0.5).float()
                # BUG FIX: accumulate a Python int, not a LongTensor —
                # tensor integer division later floored the accuracy to 0.
                running_corrects += torch.sum(preds == labels).item()
                confusion_matrix[phase].add(preds, labels)
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            costs[phase].append(epoch_loss)
            accs[phase].append(epoch_acc)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            print('Confusion Meter:\n', confusion_matrix[phase].value())
            # deep copy the model
            if phase == 'valid':
                scheduler.step(epoch_loss)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
        time_elapsed = time.time() - since
        print('Time elapsed: {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best valid Acc: {:4f}'.format(best_acc))
    plot_training(costs, accs)
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


def get_metrics(model, criterion, dataloaders, dataset_sizes, phase='valid'):
    """Report loss, accuracy and a confusion meter of ``model`` on one phase.

    Loops over the ``phase`` ('train' or 'valid') dataloader and prints the
    mean loss, the accuracy, and the normalized 2-class confusion meter.

    Args:
        model: trained network; inputs are moved to the model's device.
        criterion: loss callable invoked as ``criterion(outputs, labels, phase)``.
        dataloaders: dict of phase name -> batch iterable (see train_model).
        dataset_sizes: dict of phase name -> number of samples.
        phase: which split to evaluate; defaults to 'valid'.
    """
    # NOTE(review): the model's train/eval mode is left untouched here —
    # confirm callers have already switched dropout/batch-norm as intended.
    confusion_matrix = meter.ConfusionMeter(2, normalized=True)
    running_loss = 0.0
    running_corrects = 0
    device = next(model.parameters()).device
    # Inference only — skip autograd bookkeeping entirely.
    with torch.no_grad():
        for i, data in enumerate(dataloaders[phase]):
            print(i, end='\r')
            labels = data['label'].type(torch.FloatTensor).to(device)
            inputs = data['images'][0].to(device)
            # forward: average per-view outputs into a study-level score
            outputs = model(inputs)
            outputs = torch.mean(outputs)
            loss = criterion(outputs, labels, phase)
            # statistics — loss weighted by batch size, as in the original.
            # BUG FIX: ``loss.data[0]`` -> ``loss.item()`` (PyTorch >= 0.4).
            running_loss += loss.item() * inputs.size(0)
            preds = (outputs > 0.5).float()
            # BUG FIX: ``.item()`` keeps the count a Python int so the final
            # division produces a float accuracy instead of flooring to 0.
            running_corrects += torch.sum(preds == labels).item()
            confusion_matrix.add(preds, labels)

    loss = running_loss / dataset_sizes[phase]
    acc = running_corrects / dataset_sizes[phase]
    print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, loss, acc))
    print('Confusion Meter:\n', confusion_matrix.value())