Diff of /training.py [000000] .. [64be90]


import numpy as np
import pandas as pd
import sys, os
from random import shuffle
import torch
import torch.nn as nn
from models.gat import GATNet
from models.gat_gcn import GAT_GCN
from models.gcn import GCNNet
from models.ginconv import GINConvNet
from utils import *

# training function at each epoch
def train(model, device, train_loader, optimizer, epoch):
    print('Training on {} samples...'.format(len(train_loader.dataset)))
    model.train()
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, data.y.view(-1, 1).float().to(device))
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            # data.x holds node features, so len(data.x) counts atoms rather than
            # samples; use the batch size to report how many samples were processed
            print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch,
                                                                           batch_idx * TRAIN_BATCH_SIZE,
                                                                           len(train_loader.dataset),
                                                                           100. * batch_idx / len(train_loader),
                                                                           loss.item()))

# prediction function: returns flattened ground-truth labels and model predictions
def predicting(model, device, loader):
    model.eval()
    total_preds = torch.Tensor()
    total_labels = torch.Tensor()
    print('Make prediction for {} samples...'.format(len(loader.dataset)))
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)
            total_preds = torch.cat((total_preds, output.cpu()), 0)
            total_labels = torch.cat((total_labels, data.y.view(-1, 1).cpu()), 0)
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()


# select the dataset and model architecture from the command-line arguments
datasets = [['davis', 'kiba'][int(sys.argv[1])]]
modeling = [GINConvNet, GATNet, GAT_GCN, GCNNet][int(sys.argv[2])]
model_st = modeling.__name__

cuda_name = "cuda:0"
if len(sys.argv) > 3:
    cuda_name = "cuda:" + str(int(sys.argv[3]))
print('cuda_name:', cuda_name)
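
# Example invocation (a sketch based on the argument handling above; the indices
# select the dataset, the model architecture, and optionally the GPU):
#   python training.py 0 0 0    # davis, GINConvNet, cuda:0
#   python training.py 1 3      # kiba, GCNNet, default cuda:0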

TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 512
LR = 0.0005
LOG_INTERVAL = 20
NUM_EPOCHS = 1000

print('Learning rate: ', LR)
print('Epochs: ', NUM_EPOCHS)

# Main program: iterate over different datasets
for dataset in datasets:
    print('\nrunning on ', model_st + '_' + dataset)
    processed_data_file_train = 'data/processed/' + dataset + '_train.pt'
    processed_data_file_test = 'data/processed/' + dataset + '_test.pt'
    if ((not os.path.isfile(processed_data_file_train)) or (not os.path.isfile(processed_data_file_test))):
        print('please run create_data.py to prepare data in pytorch format!')
    else:
        train_data = TestbedDataset(root='data', dataset=dataset+'_train')
        test_data = TestbedDataset(root='data', dataset=dataset+'_test')

        # make data PyTorch mini-batch processing ready
        train_loader = DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_data, batch_size=TEST_BATCH_SIZE, shuffle=False)

        # training the model
        device = torch.device(cuda_name if torch.cuda.is_available() else "cpu")
        model = modeling().to(device)
        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        best_mse = 1000
        best_ci = 0
        best_epoch = -1
        model_file_name = 'model_' + model_st + '_' + dataset + '.model'
        result_file_name = 'result_' + model_st + '_' + dataset + '.csv'
        for epoch in range(NUM_EPOCHS):
            train(model, device, train_loader, optimizer, epoch + 1)
            G, P = predicting(model, device, test_loader)
            # evaluation metrics on the test set: rmse, mse, pearson, spearman, ci
            ret = [rmse(G, P), mse(G, P), pearson(G, P), spearman(G, P), ci(G, P)]
            if ret[1] < best_mse:
                # save the best model (by test MSE) together with its metrics
                torch.save(model.state_dict(), model_file_name)
                with open(result_file_name, 'w') as f:
                    f.write(','.join(map(str, ret)))
                best_epoch = epoch + 1
                best_mse = ret[1]
                best_ci = ret[-1]
                print('rmse improved at epoch ', best_epoch, '; best_mse,best_ci:', best_mse, best_ci, model_st, dataset)
            else:
                print(ret[1], 'No improvement since epoch ', best_epoch, '; best_mse,best_ci:', best_mse, best_ci, model_st, dataset)
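
# A minimal sketch of how the saved checkpoint could be reloaded for inference,
# assuming the same model class, device and test_loader as above:
#   model = modeling().to(device)
#   model.load_state_dict(torch.load(model_file_name, map_location=device))
#   G, P = predicting(model, device, test_loader)
#   print('test mse:', mse(G, P), 'test ci:', ci(G, P))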