b/SGCN/SGCN.py
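"""SGCN: a simple graph convolutional network for whole-graph classification.

Loads train/test graph matrices from a .mat file, aggregates each graph
with a learned weight over a k-hop power of the input matrix, and
classifies the aggregated signal with a small 1-D CNN.
"""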
import argparse
import time

import numpy as np
import scipy.io as sio
import torch
from torch.utils.data import TensorDataset, DataLoader

def readfile(path):
    """Read train/test data and labels from a MATLAB .mat file."""
    print('reading file ...')
    data = sio.loadmat(path)

    x_train = np.array(data['train_data'], dtype=float)
    x_label = np.array(data['train_label'], dtype=int)
    val_data = np.array(data['test_data'], dtype=float)
    val_label = np.array(data['test_label'], dtype=int)

    x_train = torch.FloatTensor(x_train)
    val_data = torch.FloatTensor(val_data)
    x_label = torch.LongTensor(x_label)
    val_label = torch.LongTensor(val_label)

    return x_train, x_label, val_data, val_label

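
# A compatible .mat file could be produced as below; the shapes are an
# assumption inferred from the permute/squeeze calls in main(), with data
# stored as (node, node, sample) and labels as (sample, 1):
#
#   sio.savemat('dataset/example.mat', {
#       'train_data':  np.zeros((256, 256, 100)),
#       'train_label': np.zeros((100, 1)),
#       'test_data':   np.zeros((256, 256, 20)),
#       'test_label':  np.zeros((20, 1)),
#   })
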
class CNNnet(torch.nn.Module):
    def __init__(self, node_number, batch_size, k_hop):
        super(CNNnet, self).__init__()
        self.node_number = node_number
        self.batch_size = batch_size
        self.k_hop = k_hop
        # Learned aggregation weight: maps each (node_number x node_number)
        # graph matrix to a single 1 x node_number row signal.
        self.aggregate_weight = torch.nn.Parameter(torch.rand(1, 1, node_number))
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1,
                            out_channels=8,
                            kernel_size=3,
                            stride=1,
                            padding=1),
            torch.nn.BatchNorm1d(8),
            torch.nn.ReLU(),
            torch.nn.MaxPool1d(kernel_size=2),
            # torch.nn.AvgPool1d(kernel_size=2),
            torch.nn.Dropout(0.2),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv1d(8, 16, 3, 1, 1),
            torch.nn.BatchNorm1d(16),
            torch.nn.ReLU(),
            torch.nn.MaxPool1d(kernel_size=2),
            # torch.nn.AvgPool1d(kernel_size=2),
            torch.nn.Dropout(0.2),
        )
        # 64*16 assumes node_number == 256: the two MaxPool1d(2) layers
        # halve the length twice (256 -> 128 -> 64), with 16 channels.
        self.mlp1 = torch.nn.Sequential(
            torch.nn.Linear(64 * 16, 50),
            torch.nn.Dropout(0.5),
        )
        self.mlp2 = torch.nn.Linear(50, 2)

    def forward(self, x):
        # Raise the graph matrix to the (k_hop + 1)-th power so the
        # aggregation mixes information from multi-hop neighbourhoods.
        tmp_x = x
        for _ in range(self.k_hop):
            tmp_x = torch.matmul(tmp_x, x)
        # (1, 1, N) x (B, N, N) -> (B, 1, N): one aggregated signal per graph.
        x = torch.matmul(self.aggregate_weight, tmp_x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.mlp1(x.view(x.size(0), -1))
        x = self.mlp2(x)
        return x
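
    # Shape check (illustrative, assuming the 256-node default):
    #   net = CNNnet(node_number=256, batch_size=4, k_hop=4)
    #   net(torch.rand(4, 256, 256)).shape  # -> torch.Size([4, 2])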


def main():
    parser = argparse.ArgumentParser(
        description='PyTorch graph convolutional neural net for whole-graph classification')
    parser.add_argument('--dataset', type=str, default="dataset/AEF_V_0.mat",
                        help='path of the dataset (default: dataset/AEF_V_0.mat)')
    parser.add_argument('--node_number', type=int, default=256,
                        help='number of graph nodes (default: 256)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size (default: 32)')
    parser.add_argument('--k_hop', type=int, default=4,
                        help='number of aggregation hops (default: 4)')

    args = parser.parse_args()
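
    # Example invocation (values shown are the defaults):
    #   python SGCN/SGCN.py --dataset dataset/AEF_V_0.mat --node_number 256 \
    #       --batch_size 32 --k_hop 4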

    x_train, x_label, val_data, val_label = readfile(args.dataset)
    # .mat arrays arrive as (node, node, sample); move the sample dimension
    # to the front and flatten the (sample, 1) label columns.
    x_train = x_train.permute(2, 0, 1)
    x_label = torch.squeeze(x_label, dim=1).long()

    val_data = val_data.permute(2, 0, 1)
    val_label = torch.squeeze(val_label, dim=1).long()

    train_set = TensorDataset(x_train, x_label)
    val_set = TensorDataset(val_data, val_label)

    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=0)
    # The validation set does not need shuffling.
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
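
    # Each batch yielded by the loaders is [inputs, labels] with shapes
    # (batch_size, node_number, node_number) and (batch_size,).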

    model = CNNnet(args.node_number, args.batch_size, args.k_hop)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # optimize all cnn parameters
    best_acc = 0.0

    num_epoch = 100

    for epoch in range(num_epoch):
        epoch_start_time = time.time()
        train_acc = 0.0
        train_loss = 0.0
        val_acc = 0.0
        val_loss = 0.0

        model.train()
        for i, data in enumerate(train_loader):
            optimizer.zero_grad()

            train_pred = model(data[0])
            batch_loss = loss(train_pred, data[1])
            batch_loss.backward()
            optimizer.step()

            train_acc += np.sum(np.argmax(train_pred.cpu().data.numpy(), axis=1) == data[1].numpy())
            train_loss += batch_loss.item()

        model.eval()
        predict_total = []
        label_total = []

        # Accumulate validation predictions without tracking gradients.
        with torch.no_grad():
            for i, data in enumerate(val_loader):
                val_pred = model(data[0])
                batch_loss = loss(val_pred, data[1])

                predict_val = np.argmax(val_pred.cpu().data.numpy(), axis=1)
                predict_total = np.append(predict_total, predict_val)
                label_val = data[1].numpy()
                label_total = np.append(label_total, label_val)

                val_acc += np.sum(predict_val == label_val)
                val_loss += batch_loss.item()

        # Confusion-matrix counts over the whole validation set
        # (positive class = 1); the 0.001 term guards against division by zero.
        val_TP = ((predict_total == 1) & (label_total == 1)).sum().item()
        val_TN = ((predict_total == 0) & (label_total == 0)).sum().item()
        val_FN = ((predict_total == 0) & (label_total == 1)).sum().item()
        val_FP = ((predict_total == 1) & (label_total == 0)).sum().item()

        val_spe = val_TN / (val_FP + val_TN + 0.001)  # specificity
        val_rec = val_TP / (val_TP + val_FN + 0.001)  # recall / sensitivity
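        # Worked example: TP=40, TN=45, FP=5, FN=10 gives
        # specificity ~ 45/50 = 0.90 and recall ~ 40/50 = 0.80.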

        val_acc = val_acc / len(val_set)
        print('[%03d/%03d] %2.2f sec(s) %3.6f %3.6f %3.6f %3.6f' %
              (epoch + 1, num_epoch, time.time() - epoch_start_time,
               train_acc / len(train_set), train_loss, val_acc, val_loss))

        if val_acc > best_acc:
            # Record the best epoch's accuracy, specificity and recall,
            # then checkpoint the model weights.
            with open('save/AET_V_0.txt', 'w') as f:
                f.write(str(epoch) + '\t' + str(val_acc) + '\t' + str(val_spe) + '\t' + str(val_rec) + '\n')
            torch.save(model.state_dict(), 'save/model.pth')
            best_acc = val_acc

    # Debug: print the first slice of every trainable parameter.
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(param[0])


if __name__ == '__main__':
    main()