Diff of /SGCN/SGCN.py [000000] .. [92ce40]

Switch to unified view

a b/SGCN/SGCN.py
1
import argparse
2
import sys
3
import torch
4
import time
5
import scipy.io as sio
6
import numpy as np
7
from torch.utils.data import TensorDataset, DataLoader
8
9
10
def readfile(path):
11
    print('reading file ...')
12
    data = sio.loadmat(path)
13
    x_train = []
14
    x_label = []
15
    val_data = []
16
    val_label = []
17
18
    x_train = data['train_data']
19
    x_label = data['train_label']
20
    val_data = data['test_data']
21
    val_label = data['test_label']
22
    
23
    x_train = np.array(x_train, dtype=float)
24
    val_data = np.array(val_data, dtype=float)
25
    x_label = np.array(x_label, dtype=int)
26
    val_label = np.array(val_label, dtype=int)
27
    x_train = torch.FloatTensor(x_train)
28
    val_data = torch.FloatTensor(val_data)
29
    x_label = torch.LongTensor(x_label)
30
    val_label = torch.LongTensor(val_label)
31
32
    return x_train, x_label, val_data, val_label
33
34
35
class CNNnet(torch.nn.Module):
36
      def __init__(self, node_number, batch_size, k_hop):
37
          super(CNNnet,self).__init__()
38
          self.node_number = node_number
39
          self.batch_size = batch_size
40
          self.k_hop = k_hop
41
          self.aggregate_weight = torch.nn.Parameter(torch.rand(1, 1, node_number))
42
          self.conv1 = torch.nn.Sequential(
43
              torch.nn.Conv1d(in_channels=1,
44
                              out_channels=8,
45
                              kernel_size=3,
46
                              stride=1,
47
                              padding=1),
48
              torch.nn.BatchNorm1d(8),
49
              torch.nn.ReLU(),
50
              torch.nn.MaxPool1d(kernel_size=2),
51
              #torch.nn.AvgPool1d(kernel_size=2),
52
              torch.nn.Dropout(0.2),
53
          )
54
          self.conv2 = torch.nn.Sequential(
55
              torch.nn.Conv1d(8,16,3,1,1),
56
              torch.nn.BatchNorm1d(16),
57
              torch.nn.ReLU(),
58
              torch.nn.MaxPool1d(kernel_size=2),
59
              #torch.nn.AvgPool1d(kernel_size=2),
60
              torch.nn.Dropout(0.2),
61
          )
62
          self.mlp1 = torch.nn.Sequential(
63
              torch.nn.Linear(64*16,50),
64
              torch.nn.Dropout(0.5),
65
          )
66
          self.mlp2 = torch.nn.Linear(50,2)
67
      def forward(self, x):
68
          tmp_x = x
69
          for _ in range(self.k_hop):
70
              tmp_x = torch.matmul(tmp_x, x)
71
          x = torch.matmul(self.aggregate_weight, tmp_x)
72
          x = self.conv1(x)
73
          x = self.conv2(x)
74
          x = self.mlp1(x.view(x.size(0),-1))
75
          x = self.mlp2(x)
76
          return x
77
78
def main():
79
80
    parser = argparse.ArgumentParser(description='PyTorch graph convolutional neural net for whole-graph classification')
81
    parser.add_argument('--dataset', type=str, default="dataset/AEF_V_0.mat", help='path of the dataset (default: data/data.mat)')
82
    parser.add_argument('--node_number', type=int, default=256, help='node number of graph (default: 256)')
83
    parser.add_argument('--batch_size', type=int, default=32, help='number of input size (default: 128)')
84
    parser.add_argument('--k_hop', type=int, default=4, help='times of aggregate (default: 1)')
85
86
    args = parser.parse_args()
87
88
    x_train, x_label, val_data, val_label = readfile(args.dataset)   # 'train.csv'
89
    x_train = x_train.permute(2, 0, 1)
90
    x_label = torch.squeeze(x_label, dim=1).long()
91
92
    val_data = val_data.permute(2, 0, 1)
93
    val_label = torch.squeeze(val_label, dim=1).long()    
94
95
    train_set = TensorDataset(x_train, x_label)
96
    val_set = TensorDataset(val_data, val_label)
97
98
    #batch_size = 128
99
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=0)
100
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=0)
101
102
    model = CNNnet(args.node_number, args.batch_size, args.k_hop)
103
    #print(model) 
104
    model  
105
    loss = torch.nn.CrossEntropyLoss()
106
    #para = list(model.parameters())
107
    #print(para)
108
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)   # optimize all cnn parameters
109
    loss_func = torch.nn.CrossEntropyLoss()
110
    best_acc = 0.0
111
112
    num_epoch = 100
113
    for epoch in range(num_epoch):
114
        epoch_start_time = time.time()
115
        train_acc = 0.0
116
        train_loss = 0.0
117
        val_acc = 0.0
118
        val_loss = 0.0
119
120
        model.train()
121
        for i, data in enumerate(train_loader):
122
            optimizer.zero_grad()
123
124
            train_pred = model(data[0]) 
125
            #print(train_pred.size())
126
            #print(data[1].size()) 
127
            batch_loss = loss(train_pred, data[1])  
128
            batch_loss.backward()  
129
            optimizer.step()  
130
131
            train_acc += np.sum(np.argmax(train_pred.cpu().data.numpy(), axis=1) == data[1].numpy())
132
            train_loss += batch_loss.item()
133
134
            
135
136
        model.eval()
137
        
138
        val_TP = 1.0
139
        val_TN = 1.0
140
        val_FN = 1.0
141
        val_FP = 1.0
142
        
143
        predict_total = []
144
        label_total = []
145
    
146
        for i, data in enumerate(val_loader):
147
            val_pred = model(data[0])
148
            batch_loss = loss(val_pred, data[1])
149
    
150
            predict_val = np.argmax(val_pred.cpu().data.numpy(), axis=1)
151
            predict_total = np.append(predict_total, predict_val)
152
            label_val = data[1].numpy()
153
            label_total = np.append(label_total, label_val)
154
            
155
            val_acc += np.sum(np.argmax(val_pred.cpu().data.numpy(), axis=1) == data[1].numpy())
156
            val_loss += batch_loss.item()
157
            
158
            
159
            
160
        val_TP = ((predict_total == 1) & (label_total == 1)).sum().item()
161
        val_TN = ((predict_total == 0) & (label_total == 0)).sum().item()
162
        val_FN = ((predict_total == 0) & (label_total == 1)).sum().item()
163
        val_FP = ((predict_total == 1) & (label_total == 0)).sum().item()
164
    
165
        val_spe = val_TN/(val_FP + val_TN + 0.001)
166
        val_rec = val_TP/(val_TP + val_FN + 0.001)
167
        test_acc = (val_TP+val_TN)/(val_FP + val_TN + val_TP + val_FN + 0.001)
168
                           
169
        val_acc = val_acc / val_set.__len__()
170
        print('%3.6f   %3.6f   %3.6f   %3.6f' % (train_acc / train_set.__len__(), train_loss, val_acc, val_loss))
171
172
        if (val_acc > best_acc):
173
            with open('save/AET_V_0.txt', 'w') as f:
174
                f.write(str(epoch) + '\t' + str(val_acc) + '\t' + str(val_spe) + '\t' + str(val_rec) + '\n')
175
            torch.save(model.state_dict(), 'save/model.pth')
176
            best_acc = val_acc
177
            
178
    for name, param in model.named_parameters():
179
        if param.requires_grad:
180
            print(param[0])
181
182
if __name__ == '__main__':
183
    main()