--- a
+++ b/src/training.py
@@ -0,0 +1,131 @@
+import torch
+import os
+import copy
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+
+def save_losses_graph(train_losses, valid_losses, experiment_dir):
+    """Plot per-epoch train/valid loss and save it as a PNG in experiment_dir."""
+    for ys, fmt, name in ((train_losses, 'g', 'Training loss'), (valid_losses, 'b', 'Validation loss')):
+        plt.plot(range(len(train_losses)), ys, fmt, label=name)
+    plt.title('Training and Validation loss')
+    plt.xlabel('Epochs')
+    plt.ylabel('Loss')
+    plt.legend(loc='upper right')
+    plt.savefig(os.path.join(experiment_dir, 'train_valid_loss.png'))
+    plt.clf()  # reset the shared pyplot figure for the next plot
+
+def save_acc_graph(train_accs, valid_accs, experiment_dir):
+    """Plot per-epoch train/valid accuracy and save it as a PNG in experiment_dir."""
+    for ys, fmt, name in ((train_accs, 'g', 'Training acc'), (valid_accs, 'b', 'Validation acc')):
+        plt.plot(range(len(train_accs)), ys, fmt, label=name)
+    plt.title('Training and Validation acc')
+    plt.xlabel('Epochs')
+    plt.ylabel('Acc')
+    plt.legend(loc='upper right')
+    plt.savefig(os.path.join(experiment_dir, 'train_valid_acc.png'))
+    plt.clf()  # reset the shared pyplot figure for the next plot
+
+
+def train(conf):
+    """Run the train/valid loop described by conf and checkpoint the best model.
+
+    conf is a mapping providing 'device', 'model', 'dataloaders' (indexed by
+    'train' and 'valid'), 'criterion', 'optimizer', 'scheduler', 'num_epochs'
+    and 'experiment_dir'. After each validation phase the loss and accuracy
+    curves are redrawn in experiment_dir, and whenever validation accuracy
+    improves the model state_dict is saved there as 'best_weights.pt'.
+    Returns None; the best validation accuracy is printed at the end.
+    """
+    device = conf['device']
+    model = conf['model'].to(device)
+    dataloaders = conf['dataloaders']
+    criterion = conf['criterion']
+    optimizer = conf['optimizer']
+    scheduler = conf['scheduler']
+    num_epochs = conf['num_epochs']
+    experiment_dir = conf['experiment_dir']
+
+    best_acc = 0.0  # stays a plain float; never rebound to a tensor
+
+    epoch_bar = tqdm(range(num_epochs), desc='Epoch', unit='epoch')
+    train_losses = []
+    valid_losses = []
+    train_accs = []
+    valid_accs = []
+    train_loss = -1.0
+    valid_loss = -1.0
+    train_acc = -1.0
+    valid_acc = -1.0
+    for epoch in epoch_bar:
+        # Each epoch has a training and validation phase
+        for phase in ['train', 'valid']:
+            if phase == 'train':
+                model.train()  # Set model to training mode
+            else:
+                model.eval()   # Set model to evaluate mode
+
+            running_loss = 0.0
+            running_corrects = 0
+
+            # Iterate over data.
+            batch_bar = tqdm(dataloaders[phase],
+                             desc='Batch',
+                             unit='batch',
+                             leave=False)
+            for inputs, labels in batch_bar:
+                inputs = inputs.to(device)
+                labels = labels.to(device)
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # forward; gradients are tracked only in the train phase
+                with torch.set_grad_enabled(phase == 'train'):
+                    outputs = model(inputs)
+                    loss = criterion(outputs, labels)
+                    _, preds = torch.max(outputs, 1)  # argmax over dim 1
+
+                    if phase == 'train':
+                        loss.backward()
+                        optimizer.step()
+
+                # statistics
+                running_corrects += torch.sum(preds == labels).item()
+                batch_loss = loss.item()
+                running_loss += batch_loss
+                batch_bar.set_postfix(phase=phase, batch_loss=batch_loss)
+
+            # Loss is the mean of per-batch losses; accuracy is per sample.
+            epoch_loss = running_loss / len(dataloaders[phase])
+            epoch_acc = running_corrects / len(dataloaders[phase].dataset)
+            if phase == 'train':
+                train_loss, train_acc = epoch_loss, epoch_acc
+                train_losses.append(train_loss)
+                train_accs.append(train_acc)
+            else:
+                valid_loss, valid_acc = epoch_loss, epoch_acc
+                valid_losses.append(valid_loss)
+                valid_accs.append(valid_acc)
+                save_losses_graph(train_losses, valid_losses, experiment_dir)
+                save_acc_graph(train_accs, valid_accs, experiment_dir)
+
+                if valid_acc > best_acc:
+                    best_acc = valid_acc
+                    # torch.save serializes immediately, so no copy is needed.
+                    weights_path = os.path.join(experiment_dir, 'best_weights.pt')
+                    torch.save(model.state_dict(), weights_path)
+
+        epoch_bar.set_postfix(tloss=train_loss, tacc=train_acc,
+                              vloss=valid_loss, vacc=valid_acc)
+        scheduler.step()
+
+    print('Best valid Acc: {:.4f}'.format(best_acc))
+
+if __name__ == '__main__':
+    # Imported only when run as a script, so importing this module does
+    # not need the config module.
+    from config import get_config
+
+    # Load the configuration from ./conf/training.yaml and train the model.
+    train(get_config('./conf/training.yaml'))