Diff of /cnnmodel/training.py [000000] .. [8c4e02]

Switch to side-by-side view

--- a
+++ b/cnnmodel/training.py
@@ -0,0 +1,186 @@
+import sys
+import time
+import tqdm
+import numpy as np
+import pandas as pd
+import torch
+
+from torch import optim
+from torch.utils.data import DataLoader
+
+from cnnmodel.model import CNNStressNet
+from cnnmodel.dataset import CNNDataset
+
+from util.pt_util import restore_objects, save_model, save_objects, restore_model
+
+
+def update_metrics(pred: torch.Tensor, label: torch.Tensor, metric_dict: dict):
+    """Accumulate one batch of binary-classification tallies into ``metric_dict``.
+
+    ``metric_dict`` is modified in place and must already hold the keys
+    ``accuracy``, ``true_pos``, ``true_neg``, ``false_pos`` and ``false_neg``.
+    ``accuracy`` grows by the number of elements where ``pred == label``; the
+    other four grow by the confusion-matrix counts, with 1 treated as the
+    positive class and 0 as the negative class.
+    """
+    predicted_pos, predicted_neg = pred == 1, pred == 0
+    actual_pos, actual_neg = label == 1, label == 0
+
+    # (key, boolean mask) pairs, listed in the order the counters are updated.
+    batch_masks = (
+        ('accuracy', pred == label),
+        ('true_pos', actual_pos & predicted_pos),
+        ('true_neg', actual_neg & predicted_neg),
+        ('false_pos', actual_neg & predicted_pos),
+        ('false_neg', actual_pos & predicted_neg),
+    )
+    for key, mask in batch_masks:
+        metric_dict[key] += torch.sum(mask).item()
+
+
+def train(model, device, train_loader, optimizer, epoch, log_interval):
+    """Run one training epoch and return ``(mean_loss, accuracy_percent, metric_dict)``.
+
+    Every batch from ``train_loader`` is unpacked as
+    ``((mfcc, non_mfcc, path), label)``.  ``model(mfcc, non_mfcc)`` yields the
+    outputs, ``model.loss(out, label)`` the loss that is back-propagated, and
+    the predicted class is the arg-max over dim 1 of the soft-maxed outputs.
+
+    Args:
+        model: network exposing ``__call__(mfcc, non_mfcc)`` and ``loss(out, label)``.
+        device: device the batch tensors are moved to.
+        train_loader: iterable of batches; ``train_loader.dataset`` must support ``len``.
+        optimizer: optimizer stepped once per batch.
+        epoch: epoch number, used only in the progress message.
+        log_interval: print progress every this many batches; ``None`` or ``0``
+            disables progress printing.
+
+    Returns:
+        Tuple of the mean of the per-batch losses, the accuracy in percent over
+        ``len(train_loader.dataset)``, and ``metric_dict``.  ``metric_dict``
+        holds the raw counts (``accuracy``, ``true_pos``, ``true_neg``,
+        ``false_pos``, ``false_neg``) plus ``batch_losses``, ``accuracy_mean``,
+        ``precision``, ``recall`` and ``f1_score``.  A ratio whose denominator
+        is zero is reported as ``0.0``.
+    """
+    model.train()
+    losses = []
+    metric_dict = {
+        'accuracy': 0,
+        'true_pos': 0,
+        'true_neg': 0,
+        'false_pos': 0,
+        'false_neg': 0
+    }
+
+    for batch_idx, ((mfcc, non_mfcc, path), label) in enumerate(tqdm.tqdm(train_loader)):
+        mfcc, non_mfcc, label = mfcc.to(device), non_mfcc.to(device), label.to(device)
+        optimizer.zero_grad()
+        out = model(mfcc, non_mfcc)
+        loss = model.loss(out, label)
+        # Metric bookkeeping must not contribute to the autograd graph.
+        with torch.no_grad():
+            prob = torch.nn.functional.softmax(out, dim=1)
+            pred = torch.argmax(prob, dim=1)
+            update_metrics(pred=pred, label=label, metric_dict=metric_dict)
+
+        batch_loss = loss.item()
+        losses.append(batch_loss)
+        loss.backward()
+        optimizer.step()
+
+        # A falsy interval (None or 0) switches logging off instead of raising
+        # TypeError / ZeroDivisionError in the modulo below.
+        if log_interval and batch_idx % log_interval == 0:
+            print('{} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                time.ctime(time.time()),
+                epoch, batch_idx * len(mfcc), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), batch_loss))
+
+    accuracy_mean = (100. * metric_dict['accuracy']) / len(train_loader.dataset)
+
+    true_pos = metric_dict['true_pos']
+    predicted_pos = true_pos + metric_dict['false_pos']
+    actual_pos = true_pos + metric_dict['false_neg']
+
+    # Early in training the model may predict a single class for every sample,
+    # which makes these denominators zero; report 0.0 rather than crashing.
+    precision = true_pos / predicted_pos if predicted_pos > 0 else 0.0
+    recall = true_pos / actual_pos if actual_pos > 0 else 0.0
+    pr_sum = precision + recall
+    f1_score = (2.0 * precision * recall) / pr_sum if pr_sum > 0 else 0.0
+
+    metric_dict['batch_losses'] = losses
+    metric_dict['accuracy_mean'] = accuracy_mean
+    metric_dict['precision'] = precision
+    metric_dict['recall'] = recall
+    metric_dict['f1_score'] = f1_score
+
+    return np.mean(losses), accuracy_mean, metric_dict
+
+
+def test(model, device, test_loader, log_interval=None, data_check_path='data_check_test.csv'):
+    """Evaluate ``model`` on ``test_loader`` and return ``(mean_loss, accuracy_percent, metric_dict)``.
+
+    Runs under ``torch.no_grad()`` with the model in eval mode.  Every batch is
+    unpacked as ``((mfcc, non_mfcc, path), label)``; the predicted class is the
+    arg-max over dim 1 of the soft-maxed model outputs.  Per-sample path,
+    label, prediction and the probabilities in columns 0 and 1 are collected
+    and written as CSV to ``data_check_path``.
+
+    Args:
+        model: network exposing ``__call__(mfcc, non_mfcc)`` and ``loss(out, label)``.
+        device: device the batch tensors are moved to.
+        test_loader: iterable of batches; ``test_loader.dataset`` must support ``len``.
+        log_interval: print progress every this many batches; ``None`` (default)
+            or ``0`` disables progress printing.
+        data_check_path: destination of the per-sample CSV; ``None`` skips
+            writing it.  Defaults to the previously hard-coded file name.
+
+    Returns:
+        Tuple of the mean of the per-batch losses, the accuracy in percent over
+        ``len(test_loader.dataset)``, and ``metric_dict``.  ``metric_dict``
+        holds the raw counts (``accuracy``, ``true_pos``, ``true_neg``,
+        ``false_pos``, ``false_neg``) plus ``batch_losses``, ``accuracy_mean``,
+        ``precision``, ``recall`` and ``f1_score``.  A ratio whose denominator
+        is zero is reported as ``0.0``.
+    """
+    model.eval()
+    losses = []
+
+    metric_dict = {
+        'accuracy': 0,
+        'true_pos': 0,
+        'true_neg': 0,
+        'false_pos': 0,
+        'false_neg': 0
+    }
+
+    data_check_dict = {'path': [], 'label': [], 'pred': [], 'prob_0': [], 'prob_1': []}
+
+    with torch.no_grad():
+        for batch_idx, ((mfcc, non_mfcc, path), label) in enumerate(tqdm.tqdm(test_loader)):
+            mfcc, non_mfcc, label = mfcc.to(device), non_mfcc.to(device), label.to(device)
+            out = model(mfcc, non_mfcc)
+            prob = torch.nn.functional.softmax(out, dim=1)
+            test_loss_on = model.loss(out, label).item()
+            losses.append(test_loss_on)
+
+            pred = torch.argmax(prob, dim=1)
+            update_metrics(pred=pred, label=label, metric_dict=metric_dict)
+
+            # A falsy interval (None or 0) means "do not log".
+            if log_interval and batch_idx % log_interval == 0:
+                print('{} Test: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                    time.ctime(time.time()),
+                    batch_idx * len(mfcc), len(test_loader.dataset),
+                    100. * batch_idx / len(test_loader), test_loss_on))
+
+            data_check_dict['path'].extend(path)
+            data_check_dict['label'].extend(label.tolist())
+            data_check_dict['pred'].extend(pred.tolist())
+            data_check_dict['prob_0'].extend(prob[:, 0].tolist())
+            data_check_dict['prob_1'].extend(prob[:, 1].tolist())
+
+    if data_check_path is not None:
+        data_check_df = pd.DataFrame(data_check_dict)
+        data_check_df.to_csv(data_check_path, index=False)
+
+    test_loss = np.mean(losses)
+    accuracy_mean = (100. * metric_dict['accuracy']) / len(test_loader.dataset)
+
+    true_pos = metric_dict['true_pos']
+    predicted_pos = true_pos + metric_dict['false_pos']
+    actual_pos = true_pos + metric_dict['false_neg']
+
+    # An untrained or collapsed model can predict one class for every sample,
+    # which makes these denominators zero; report 0.0 rather than crashing.
+    precision = true_pos / predicted_pos if predicted_pos > 0 else 0.0
+    recall = true_pos / actual_pos if actual_pos > 0 else 0.0
+    pr_sum = precision + recall
+    f1_score = (2.0 * precision * recall) / pr_sum if pr_sum > 0 else 0.0
+
+    metric_dict['batch_losses'] = losses
+    metric_dict['accuracy_mean'] = accuracy_mean
+    metric_dict['precision'] = precision
+    metric_dict['recall'] = recall
+    metric_dict['f1_score'] = f1_score
+
+    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{}, ({:.4f})%\n'.format(
+        test_loss, metric_dict['accuracy'], len(test_loader.dataset), accuracy_mean))
+
+    return test_loss, accuracy_mean, metric_dict
+
+
+def main(train_path, test_path, model_path, learning_rate, epochs, batch_size=512, log_interval=250):
+    """Train ``CNNStressNet`` and checkpoint it whenever test accuracy improves.
+
+    Builds ``CNNDataset`` loaders for the train and test roots, restores the
+    model and the training history from ``model_path``, evaluates once before
+    training, then runs ``epochs`` train/evaluate rounds with an Adam
+    optimizer.  After each round the model and history are saved if the test
+    accuracy exceeds the best accuracy seen so far.
+
+    Args:
+        train_path: root directory of the training data.
+        test_path: root directory of the test data.
+        model_path: location used by the ``util.pt_util`` save/restore helpers.
+        learning_rate: Adam learning rate; a numeric string (e.g. straight
+            from ``sys.argv``) is accepted and converted with ``float``.
+        epochs: number of epochs to run; a numeric string is accepted and
+            converted with ``int``.
+        batch_size: batch size for both loaders.
+        log_interval: number of batches between training progress messages.
+    """
+    # Command-line callers pass strings; Adam and range() need real numbers.
+    learning_rate = float(learning_rate)
+    epochs = int(epochs)
+
+    print('train path: {}'.format(train_path))
+    print('test path: {}'.format(test_path))
+    print('model path: {}'.format(model_path))
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+    if use_cuda:
+        # Only touch the CUDA runtime when it exists; calling this on a
+        # CPU-only host raises and would defeat the CPU fallback above.
+        torch.cuda.current_device()
+    print('using device', device)
+
+    import multiprocessing
+    num_cpus = multiprocessing.cpu_count()
+    print('num cpus:', num_cpus)
+
+    kwargs = {'num_workers': num_cpus,
+              'pin_memory': True} if use_cuda else {}
+
+    train_dataset = CNNDataset(root=train_path)
+    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
+
+    test_dataset = CNNDataset(root=test_path)
+    # Evaluation gains nothing from shuffling; a fixed order keeps the
+    # per-sample CSV written by test() comparable between runs.
+    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **kwargs)
+
+    print('Folder to Index: {}'.format(train_dataset.folder_to_index))
+
+    model = CNNStressNet(reduction='mean').to(device)
+    model = restore_model(model, model_path)
+    # NOTE(review): the second argument looks like the defaults returned when
+    # no checkpoint exists -- confirm against util.pt_util.restore_objects.
+    last_epoch, max_accuracy, train_losses, test_losses, all_train_metrics, all_test_metrics = \
+        restore_objects(model_path, (0, 0, [], [], [], []))
+
+    # Continue after the restored epoch only if a previous best accuracy exists.
+    start = last_epoch + 1 if max_accuracy > 0 else 0
+
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+    test_loss, test_accuracy, test_metrics = test(model, device, test_loader)
+    print('Before any training:, test loss is: {}, test_accuracy: {}'.format(test_loss, test_accuracy))
+
+    for epoch in range(start, start + epochs):
+        train_loss, train_accuracy, train_metrics = train(
+            model, device, train_loader, optimizer, epoch, log_interval)
+        test_loss, test_accuracy, test_metrics = test(model, device, test_loader)
+        print('After epoch: {}, train_loss: {}, test loss is: {}, train_accuracy: {}, test_accuracy: {}'.format(
+            epoch, train_loss, test_loss, train_accuracy, test_accuracy))
+
+        train_losses.append(train_loss)
+        test_losses.append(test_loss)
+        all_train_metrics.append(train_metrics)
+        all_test_metrics.append(test_metrics)
+
+        # Checkpoint only on a strict improvement of the test accuracy.
+        if test_accuracy > max_accuracy:
+            max_accuracy = test_accuracy
+            save_model(model, epoch, model_path)
+            save_objects((epoch, max_accuracy, train_losses, test_losses, all_train_metrics, all_test_metrics),
+                         epoch, model_path)
+            print('saved epoch: {} as checkpoint'.format(epoch))
+
+
+if __name__ == '__main__':
+    # Needs five command line arguments:
+    # 1. root path of train data
+    # 2. root path of test data
+    # 3. path where saved models are saved
+    # 4. learning rate (float)
+    # 5. number of epochs (int)
+    if len(sys.argv) < 6:
+        sys.exit('usage: {} TRAIN_PATH TEST_PATH MODEL_PATH LEARNING_RATE EPOCHS'.format(sys.argv[0]))
+    # argv entries are strings; convert the numeric ones before calling main().
+    main(sys.argv[1], sys.argv[2], sys.argv[3], float(sys.argv[4]), int(sys.argv[5]))