|
a |
|
b/training.py |
|
|
1 |
import numpy as np |
|
|
2 |
import pandas as pd |
|
|
3 |
import sys, os |
|
|
4 |
from random import shuffle |
|
|
5 |
import torch |
|
|
6 |
import torch.nn as nn |
|
|
7 |
from models.gat import GATNet |
|
|
8 |
from models.gat_gcn import GAT_GCN |
|
|
9 |
from models.gcn import GCNNet |
|
|
10 |
from models.ginconv import GINConvNet |
|
|
11 |
from utils import * |
|
|
12 |
|
|
|
13 |
# training function at each epoch |
|
|
14 |
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch of `model` over `train_loader`.

    Args:
        model: network mapping a batched graph `data` object to predictions.
        device: torch device the batches and targets are moved to.
        train_loader: DataLoader of batched graph data; each batch exposes
            `.x` (node features), `.y` (targets) and `.num_graphs`.
        optimizer: torch optimizer over `model.parameters()`.
        epoch: 1-based epoch number, used only in log messages.

    NOTE(review): reads module-level `loss_fn` (MSELoss) and `LOG_INTERVAL`,
    which are defined in the script body further down this file.
    """
    print('Training on {} samples...'.format(len(train_loader.dataset)))
    model.train()
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, data.y.view(-1, 1).float().to(device))
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            # BUG FIX: the original used `batch_idx * len(data.x)`, but
            # `data.x` is the node-feature matrix of the batched graph, so
            # its length is the number of *nodes*, not samples; the progress
            # counter overshot the dataset size. `data.num_graphs` is the
            # number of graphs (samples) in the mini-batch.
            print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch,
                                                                           batch_idx * data.num_graphs,
                                                                           len(train_loader.dataset),
                                                                           100. * batch_idx / len(train_loader),
                                                                           loss.item()))
|
|
30 |
|
|
|
31 |
def predicting(model, device, loader):
    """Run `model` in eval mode over `loader` and gather outputs.

    Returns:
        A pair of flat 1-D numpy arrays: (ground-truth labels, predictions).
    """
    model.eval()
    print('Make prediction for {} samples...'.format(len(loader.dataset)))
    # Seed each chunk list with an empty tensor so that an empty loader
    # still concatenates to an empty result, exactly like the original.
    label_chunks = [torch.Tensor()]
    pred_chunks = [torch.Tensor()]
    with torch.no_grad():
        for data in loader:
            batch = data.to(device)
            out = model(batch)
            pred_chunks.append(out.cpu())
            label_chunks.append(batch.y.view(-1, 1).cpu())
    # Concatenate once at the end instead of growing tensors per batch.
    all_labels = torch.cat(label_chunks, 0)
    all_preds = torch.cat(pred_chunks, 0)
    return all_labels.numpy().flatten(), all_preds.numpy().flatten()
|
|
43 |
|
|
|
44 |
|
|
|
45 |
# --- command-line configuration -------------------------------------------
# argv[1]: dataset index (0 = davis, 1 = kiba)
# argv[2]: model index   (0 = GINConvNet, 1 = GATNet, 2 = GAT_GCN, 3 = GCNNet)
# argv[3]: optional CUDA device index (defaults to device 0)
datasets = [('davis', 'kiba')[int(sys.argv[1])]]
modeling = (GINConvNet, GATNet, GAT_GCN, GCNNet)[int(sys.argv[2])]
model_st = modeling.__name__

cuda_name = ("cuda:" + str(int(sys.argv[3]))) if len(sys.argv) > 3 else "cuda:0"
print('cuda_name:', cuda_name)

# Training hyperparameters (LOG_INTERVAL and the batch sizes are read by the
# functions/loop elsewhere in this script).
TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 512
LR = 0.0005
LOG_INTERVAL = 20
NUM_EPOCHS = 1000

print('Learning rate: ', LR)
print('Epochs: ', NUM_EPOCHS)
|
|
62 |
|
|
|
63 |
# Main program: iterate over different datasets |
|
|
64 |
# Main program: iterate over different datasets
for dataset in datasets:
    print('\nrunning on ', model_st + '_' + dataset)
    train_file = 'data/processed/' + dataset + '_train.pt'
    test_file = 'data/processed/' + dataset + '_test.pt'

    # Guard clause: both preprocessed tensor files must exist to proceed.
    if not (os.path.isfile(train_file) and os.path.isfile(test_file)):
        print('please run create_data.py to prepare data in pytorch format!')
        continue

    train_data = TestbedDataset(root='data', dataset=dataset + '_train')
    test_data = TestbedDataset(root='data', dataset=dataset + '_test')

    # make data PyTorch mini-batch processing ready
    train_loader = DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=TEST_BATCH_SIZE, shuffle=False)

    # Model/optimizer setup. NOTE: `loss_fn` is read globally by train(),
    # so the name must stay as-is.
    device = torch.device(cuda_name if torch.cuda.is_available() else "cpu")
    model = modeling().to(device)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    best_mse = 1000
    best_ci = 0
    best_epoch = -1
    model_file_name = 'model_' + model_st + '_' + dataset + '.model'
    result_file_name = 'result_' + model_st + '_' + dataset + '.csv'

    for epoch in range(NUM_EPOCHS):
        train(model, device, train_loader, optimizer, epoch + 1)
        labels, preds = predicting(model, device, test_loader)
        scores = [rmse(labels, preds), mse(labels, preds), pearson(labels, preds),
                  spearman(labels, preds), ci(labels, preds)]
        if scores[1] < best_mse:
            # Checkpoint weights and metrics on every test-MSE improvement.
            torch.save(model.state_dict(), model_file_name)
            with open(result_file_name, 'w') as f:
                f.write(','.join(map(str, scores)))
            best_epoch = epoch + 1
            best_mse = scores[1]
            best_ci = scores[-1]
            print('rmse improved at epoch ', best_epoch, '; best_mse,best_ci:', best_mse, best_ci, model_st, dataset)
        else:
            print(scores[1], 'No improvement since epoch ', best_epoch, '; best_mse,best_ci:', best_mse, best_ci, model_st, dataset)
|
|
102 |
|