# training_validation.py
1
import numpy as np
2
import pandas as pd
3
import sys, os
4
from random import shuffle
5
import torch
6
import torch.nn as nn
7
from models.gat import GATNet
8
from models.gat_gcn import GAT_GCN
9
from models.gcn import GCNNet
10
from models.ginconv import GINConvNet
11
from utils import *
12
13
# training function at each epoch
def train(model, device, train_loader, optimizer, epoch, criterion=None, log_interval=None):
    """Run one training epoch over ``train_loader`` and print progress.

    Args:
        model: module called as ``model(data)`` on each batch; its output is
            compared against ``data.y.view(-1, 1)`` cast to float.
        device: torch device every batch (and its targets) is moved to.
        train_loader: iterable of batches exposing ``.to(device)`` and ``.y``;
            ``len(train_loader)`` and ``len(train_loader.dataset)`` are used
            for the progress message.
        optimizer: optimizer that updates ``model``'s parameters.
        epoch: epoch number shown in the log line.
        criterion: loss callable ``criterion(output, target)``. Defaults to
            the module-level ``loss_fn`` assigned by the training script.
        log_interval: print a progress line every this many batches.
            Defaults to the module-level ``LOG_INTERVAL``.
    """
    if criterion is None:
        criterion = loss_fn
    if log_interval is None:
        log_interval = LOG_INTERVAL

    print('Training on {} samples...'.format(len(train_loader.dataset)))
    model.train()
    samples_seen = 0  # samples processed before the current batch
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        target = data.y.view(-1, 1).float().to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # Progress is counted in samples (one target row per sample) to match
            # the len(dataset) denominator; len(data.x) counts feature rows, which
            # need not equal the number of samples in a batch.
            print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                samples_seen,
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()))
        samples_seen += target.size(0)
31
def predicting(model, device, loader):
    """Run ``model`` over every batch of ``loader`` without gradients.

    The model is switched to eval mode and left there.

    Args:
        model: module called as ``model(data)`` on each batch.
        device: torch device each batch is moved to.
        loader: iterable of batches exposing ``.to(device)`` and ``.y``;
            ``len(loader.dataset)`` is printed.

    Returns:
        ``(labels, preds)``: two flattened NumPy arrays in loader order.
        Both are empty when the loader yields no batches.
    """
    model.eval()
    # Each list is seeded with an empty default-dtype tensor, so the single
    # torch.cat below always has an input (empty loader -> empty arrays) and
    # applies the same dtype promotion as concatenating onto torch.Tensor().
    pred_chunks = [torch.Tensor()]
    label_chunks = [torch.Tensor()]
    print('Make prediction for {} samples...'.format(len(loader.dataset)))
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)
            pred_chunks.append(output.cpu())
            label_chunks.append(data.y.view(-1, 1).cpu())
    # Concatenate once: calling torch.cat per batch re-copied everything
    # accumulated so far, which is quadratic in the number of batches.
    total_preds = torch.cat(pred_chunks, 0)
    total_labels = torch.cat(label_chunks, 0)
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()
43
44
45
# Command line: DATASET_IDX MODEL_IDX [CUDA_IDX]
_DATASET_NAMES = ['davis', 'kiba']
_MODEL_CLASSES = [GINConvNet, GATNet, GAT_GCN, GCNNet]

if len(sys.argv) < 3:
    # Exit with a usage message instead of an IndexError traceback.
    sys.exit('usage: {} DATASET_IDX MODEL_IDX [CUDA_IDX]\n'
             '  DATASET_IDX: {}\n'
             '  MODEL_IDX:   {}\n'
             '  CUDA_IDX:    CUDA device index (default 0)'.format(
                 sys.argv[0],
                 ', '.join('{}={}'.format(i, name) for i, name in enumerate(_DATASET_NAMES)),
                 ', '.join('{}={}'.format(i, cls.__name__) for i, cls in enumerate(_MODEL_CLASSES))))

datasets = [_DATASET_NAMES[int(sys.argv[1])]]
modeling = _MODEL_CLASSES[int(sys.argv[2])]
model_st = modeling.__name__

cuda_name = "cuda:0"
if len(sys.argv) > 3:
    # Any CUDA device index is accepted (previously only 0 or 1).
    cuda_name = "cuda:{}".format(int(sys.argv[3]))
print('cuda_name:', cuda_name)

TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 512
LR = 0.0005
LOG_INTERVAL = 20  # read as a global by train()
NUM_EPOCHS = 1000

print('Learning rate: ', LR)
print('Epochs: ', NUM_EPOCHS)
62
63
# Main program: iterate over different datasets
for dataset in datasets:
    print('\nrunning on ', model_st + '_' + dataset)
    processed_data_file_train = 'data/processed/' + dataset + '_train.pt'
    processed_data_file_test = 'data/processed/' + dataset + '_test.pt'
    if not (os.path.isfile(processed_data_file_train) and os.path.isfile(processed_data_file_test)):
        print('please run create_data.py to prepare data in pytorch format!')
        continue

    train_data = TestbedDataset(root='data', dataset=dataset+'_train')
    test_data = TestbedDataset(root='data', dataset=dataset+'_test')

    # Hold out 20% of the training set for validation (model selection).
    # NOTE: random_split is unseeded here, so the split differs between runs.
    train_size = int(0.8 * len(train_data))
    valid_size = len(train_data) - train_size
    train_data, valid_data = torch.utils.data.random_split(train_data, [train_size, valid_size])

    # make data PyTorch mini-batch processing ready
    train_loader = DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=TEST_BATCH_SIZE, shuffle=False)

    # training the model
    device = torch.device(cuda_name if torch.cuda.is_available() else "cpu")
    model = modeling().to(device)
    loss_fn = nn.MSELoss()  # module-level name: train() reads loss_fn as a global
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # Infinite sentinels: a fixed 1000 would block checkpointing whenever the
    # validation MSE exceeds it.
    best_mse = float('inf')
    best_test_mse = float('inf')
    best_test_ci = 0
    best_epoch = -1
    model_file_name = 'model_' + model_st + '_' + dataset + '.model'
    result_file_name = 'result_' + model_st + '_' + dataset + '.csv'
    for epoch in range(NUM_EPOCHS):
        train(model, device, train_loader, optimizer, epoch+1)
        print('predicting for valid data')
        G, P = predicting(model, device, valid_loader)
        val = mse(G, P)
        if val < best_mse:
            # New best validation MSE: checkpoint, then score the test set.
            best_mse = val
            best_epoch = epoch+1
            torch.save(model.state_dict(), model_file_name)
            print('predicting for test data')
            G, P = predicting(model, device, test_loader)
            ret = [rmse(G, P), mse(G, P), pearson(G, P), spearman(G, P), ci(G, P)]
            with open(result_file_name, 'w') as f:
                f.write(','.join(map(str, ret)))
            best_test_mse = ret[1]
            best_test_ci = ret[-1]
            print('rmse improved at epoch ', best_epoch, '; best_test_mse,best_test_ci:', best_test_mse, best_test_ci, model_st, dataset)
        else:
            # Report this epoch's validation MSE. The previous ret[1] was a stale
            # test MSE and raised NameError if no epoch had improved yet.
            print(val, 'No improvement since epoch ', best_epoch, '; best_test_mse,best_test_ci:', best_test_mse, best_test_ci, model_st, dataset)
115