|
a |
|
b/GCN_run.py |
|
|
1 |
#!/usr/bin/env python |
|
|
2 |
# -*- coding: utf-8 -*- |
|
|
3 |
# @Time : 2021/8/8 16:43 |
|
|
4 |
# @Author : Li Xiao |
|
|
5 |
# @File : GCN_run.py |
|
|
6 |
import numpy as np |
|
|
7 |
import pandas as pd |
|
|
8 |
import argparse |
|
|
9 |
import glob |
|
|
10 |
import os |
|
|
11 |
from sklearn.model_selection import StratifiedKFold |
|
|
12 |
from sklearn.metrics import f1_score |
|
|
13 |
import torch |
|
|
14 |
import torch.nn.functional as F |
|
|
15 |
from gcn_model import GCN |
|
|
16 |
from utils import load_data |
|
|
17 |
from utils import accuracy |
|
|
18 |
|
|
|
19 |
def setup_seed(seed): |
|
|
20 |
torch.manual_seed(seed) |
|
|
21 |
np.random.seed(seed) |
|
|
22 |
|
|
|
23 |
def train(epoch, optimizer, features, adj, labels, idx_train): |
|
|
24 |
''' |
|
|
25 |
:param epoch: training epochs |
|
|
26 |
:param optimizer: training optimizer, Adam optimizer |
|
|
27 |
:param features: the omics features |
|
|
28 |
:param adj: the laplace adjacency matrix |
|
|
29 |
:param labels: sample labels |
|
|
30 |
:param idx_train: the index of trained samples |
|
|
31 |
''' |
|
|
32 |
labels.to(device) |
|
|
33 |
|
|
|
34 |
GCN_model.train() |
|
|
35 |
optimizer.zero_grad() |
|
|
36 |
output = GCN_model(features, adj) |
|
|
37 |
loss_train = F.cross_entropy(output[idx_train], labels[idx_train]) |
|
|
38 |
acc_train = accuracy(output[idx_train], labels[idx_train]) |
|
|
39 |
loss_train.backward() |
|
|
40 |
optimizer.step() |
|
|
41 |
if (epoch+1) % 10 ==0: |
|
|
42 |
print('Epoch: %.2f | loss train: %.4f | acc train: %.4f' %(epoch+1, loss_train.item(), acc_train.item())) |
|
|
43 |
return loss_train.data.item() |
|
|
44 |
|
|
|
45 |
def test(features, adj, labels, idx_test): |
|
|
46 |
''' |
|
|
47 |
:param features: the omics features |
|
|
48 |
:param adj: the laplace adjacency matrix |
|
|
49 |
:param labels: sample labels |
|
|
50 |
:param idx_test: the index of tested samples |
|
|
51 |
''' |
|
|
52 |
GCN_model.eval() |
|
|
53 |
output = GCN_model(features, adj) |
|
|
54 |
loss_test = F.cross_entropy(output[idx_test], labels[idx_test]) |
|
|
55 |
|
|
|
56 |
#calculate the accuracy |
|
|
57 |
acc_test = accuracy(output[idx_test], labels[idx_test]) |
|
|
58 |
|
|
|
59 |
#output is the one-hot label |
|
|
60 |
ot = output[idx_test].detach().cpu().numpy() |
|
|
61 |
#change one-hot label to digit label |
|
|
62 |
ot = np.argmax(ot, axis=1) |
|
|
63 |
#original label |
|
|
64 |
lb = labels[idx_test].detach().cpu().numpy() |
|
|
65 |
print('predict label: ', ot) |
|
|
66 |
print('original label: ', lb) |
|
|
67 |
|
|
|
68 |
#calculate the f1 score |
|
|
69 |
f = f1_score(ot, lb, average='weighted') |
|
|
70 |
|
|
|
71 |
print("Test set results:", |
|
|
72 |
"loss= {:.4f}".format(loss_test.item()), |
|
|
73 |
"accuracy= {:.4f}".format(acc_test.item())) |
|
|
74 |
|
|
|
75 |
#return accuracy and f1 score |
|
|
76 |
return acc_test.item(), f |
|
|
77 |
|
|
|
78 |
def predict(features, adj, sample, idx): |
|
|
79 |
''' |
|
|
80 |
:param features: the omics features |
|
|
81 |
:param adj: the laplace adjacency matrix |
|
|
82 |
:param sample: all sample names |
|
|
83 |
:param idx: the index of predict samples |
|
|
84 |
:return: |
|
|
85 |
''' |
|
|
86 |
GCN_model.eval() |
|
|
87 |
output = GCN_model(features, adj) |
|
|
88 |
predict_label = output.detach().cpu().numpy() |
|
|
89 |
predict_label = np.argmax(predict_label, axis=1).tolist() |
|
|
90 |
#print(predict_label) |
|
|
91 |
|
|
|
92 |
res_data = pd.DataFrame({'Sample':sample, 'predict_label':predict_label}) |
|
|
93 |
res_data = res_data.iloc[idx,:] |
|
|
94 |
#print(res_data) |
|
|
95 |
|
|
|
96 |
res_data.to_csv('result/GCN_predicted_data.csv', header=True, index=False) |
|
|
97 |
|
|
|
98 |
if __name__ == '__main__': |
|
|
99 |
parser = argparse.ArgumentParser() |
|
|
100 |
parser.add_argument('--featuredata', '-fd', type=str, required=True, help='The vector feature file.') |
|
|
101 |
parser.add_argument('--adjdata', '-ad', type=str, required=True, help='The adjacency matrix file.') |
|
|
102 |
parser.add_argument('--labeldata', '-ld', type=str, required=True, help='The sample label file.') |
|
|
103 |
parser.add_argument('--testsample', '-ts', type=str, help='Test sample names file.') |
|
|
104 |
parser.add_argument('--mode', '-m', type=int, choices=[0,1], default=0, |
|
|
105 |
help='mode 0: 10-fold cross validation; mode 1: train and test a model.') |
|
|
106 |
parser.add_argument('--seed', '-s', type=int, default=0, help='Random seed, default=0.') |
|
|
107 |
parser.add_argument('--device', '-d', type=str, choices=['cpu', 'gpu'], default='cpu', |
|
|
108 |
help='Training on cpu or gpu, default: cpu.') |
|
|
109 |
parser.add_argument('--epochs', '-e', type=int, default=150, help='Training epochs, default: 150.') |
|
|
110 |
parser.add_argument('--learningrate', '-lr', type=float, default=0.001, help='Learning rate, default: 0.001.') |
|
|
111 |
parser.add_argument('--weight_decay', '-w', type=float, default=0.01, |
|
|
112 |
help='Weight decay (L2 loss on parameters), methods to avoid overfitting, default: 0.01') |
|
|
113 |
parser.add_argument('--hidden', '-hd',type=int, default=64, help='Hidden layer dimension, default: 64.') |
|
|
114 |
parser.add_argument('--dropout', '-dp', type=float, default=0.5, help='Dropout rate, methods to avoid overfitting, default: 0.5.') |
|
|
115 |
parser.add_argument('--threshold', '-t', type=float, default=0.005, help='Threshold to filter edges, default: 0.005') |
|
|
116 |
parser.add_argument('--nclass', '-nc', type=int, default=4, help='Number of classes, default: 4') |
|
|
117 |
parser.add_argument('--patience', '-p', type=int, default=20, help='Patience') |
|
|
118 |
args = parser.parse_args() |
|
|
119 |
|
|
|
120 |
# Check whether GPUs are available |
|
|
121 |
device = torch.device('cpu') |
|
|
122 |
if args.device == 'gpu': |
|
|
123 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
124 |
|
|
|
125 |
# set random seed |
|
|
126 |
setup_seed(args.seed) |
|
|
127 |
|
|
|
128 |
# load input files |
|
|
129 |
adj, data, label = load_data(args.adjdata, args.featuredata, args.labeldata, args.threshold) |
|
|
130 |
|
|
|
131 |
# change dataframe to Tensor |
|
|
132 |
adj = torch.tensor(adj, dtype=torch.float, device=device) |
|
|
133 |
features = torch.tensor(data.iloc[:, 1:].values, dtype=torch.float, device=device) |
|
|
134 |
labels = torch.tensor(label.iloc[:, 1].values, dtype=torch.long, device=device) |
|
|
135 |
|
|
|
136 |
print('Begin training model...') |
|
|
137 |
|
|
|
138 |
# 10-fold cross validation |
|
|
139 |
if args.mode == 0: |
|
|
140 |
skf = StratifiedKFold(n_splits=10, shuffle=True) |
|
|
141 |
|
|
|
142 |
acc_res, f1_res = [], [] #record accuracy and f1 score |
|
|
143 |
|
|
|
144 |
# split train and test data |
|
|
145 |
for idx_train, idx_test in skf.split(data.iloc[:, 1:], label.iloc[:, 1]): |
|
|
146 |
# initialize a model |
|
|
147 |
GCN_model = GCN(n_in=features.shape[1], n_hid=args.hidden, n_out=args.nclass, dropout=args.dropout) |
|
|
148 |
GCN_model.to(device) |
|
|
149 |
|
|
|
150 |
# define the optimizer |
|
|
151 |
optimizer = torch.optim.Adam(GCN_model.parameters(), lr=args.learningrate, weight_decay=args.weight_decay) |
|
|
152 |
|
|
|
153 |
idx_train, idx_test= torch.tensor(idx_train, dtype=torch.long, device=device), torch.tensor(idx_test, dtype=torch.long, device=device) |
|
|
154 |
for epoch in range(args.epochs): |
|
|
155 |
train(epoch, optimizer, features, adj, labels, idx_train) |
|
|
156 |
|
|
|
157 |
# calculate the accuracy and f1 score |
|
|
158 |
ac, f1= test(features, adj, labels, idx_test) |
|
|
159 |
acc_res.append(ac) |
|
|
160 |
f1_res.append(f1) |
|
|
161 |
print('10-fold Acc(%.4f, %.4f) F1(%.4f, %.4f)' % (np.mean(acc_res), np.std(acc_res), np.mean(f1_res), np.std(f1_res))) |
|
|
162 |
#predict(features, adj, data['Sample'].tolist(), data.index.tolist()) |
|
|
163 |
|
|
|
164 |
elif args.mode == 1: |
|
|
165 |
# load test samples |
|
|
166 |
test_sample_df = pd.read_csv(args.testsample, header=0, index_col=None) |
|
|
167 |
test_sample = test_sample_df.iloc[:, 0].tolist() |
|
|
168 |
all_sample = data['Sample'].tolist() |
|
|
169 |
train_sample = list(set(all_sample)-set(test_sample)) |
|
|
170 |
|
|
|
171 |
#get index of train samples and test samples |
|
|
172 |
train_idx = data[data['Sample'].isin(train_sample)].index.tolist() |
|
|
173 |
test_idx = data[data['Sample'].isin(test_sample)].index.tolist() |
|
|
174 |
|
|
|
175 |
GCN_model = GCN(n_in=features.shape[1], n_hid=args.hidden, n_out=args.nclass, dropout=args.dropout) |
|
|
176 |
GCN_model.to(device) |
|
|
177 |
optimizer = torch.optim.Adam(GCN_model.parameters(), lr=args.learningrate, weight_decay=args.weight_decay) |
|
|
178 |
idx_train, idx_test = torch.tensor(train_idx, dtype=torch.long, device=device), torch.tensor(test_idx, dtype=torch.long, device=device) |
|
|
179 |
|
|
|
180 |
''' |
|
|
181 |
save a best model (with the minimum loss value) |
|
|
182 |
if the loss didn't decrease in N epochs,stop the train process. |
|
|
183 |
N can be set by args.patience |
|
|
184 |
''' |
|
|
185 |
loss_values = [] #record the loss value of each epoch |
|
|
186 |
# record the times with no loss decrease, record the best epoch |
|
|
187 |
bad_counter, best_epoch = 0, 0 |
|
|
188 |
best = 1000 #record the lowest loss value |
|
|
189 |
for epoch in range(args.epochs): |
|
|
190 |
loss_values.append(train(epoch, optimizer, features, adj, labels, idx_train)) |
|
|
191 |
if loss_values[-1] < best: |
|
|
192 |
best = loss_values[-1] |
|
|
193 |
best_epoch = epoch |
|
|
194 |
bad_counter = 0 |
|
|
195 |
else: |
|
|
196 |
bad_counter += 1 #In this epoch, the loss value didn't decrease |
|
|
197 |
|
|
|
198 |
if bad_counter == args.patience: |
|
|
199 |
break |
|
|
200 |
|
|
|
201 |
#save model of this epoch |
|
|
202 |
torch.save(GCN_model.state_dict(), 'model/GCN/{}.pkl'.format(epoch)) |
|
|
203 |
|
|
|
204 |
#reserve the best model, delete other models |
|
|
205 |
files = glob.glob('model/GCN/*.pkl') |
|
|
206 |
for file in files: |
|
|
207 |
name = file.split('\\')[1] |
|
|
208 |
epoch_nb = int(name.split('.')[0]) |
|
|
209 |
#print(file, name, epoch_nb) |
|
|
210 |
if epoch_nb != best_epoch: |
|
|
211 |
os.remove(file) |
|
|
212 |
|
|
|
213 |
print('Training finished.') |
|
|
214 |
print('The best epoch model is ',best_epoch) |
|
|
215 |
GCN_model.load_state_dict(torch.load('model/GCN/{}.pkl'.format(best_epoch))) |
|
|
216 |
predict(features, adj, all_sample, test_idx) |
|
|
217 |
|
|
|
218 |
print('Finished!') |