--- a +++ b/SGCN/SGCN.py @@ -0,0 +1,183 @@ +import argparse +import sys +import torch +import time +import scipy.io as sio +import numpy as np +from torch.utils.data import TensorDataset, DataLoader + + +def readfile(path): + print('reading file ...') + data = sio.loadmat(path) + x_train = [] + x_label = [] + val_data = [] + val_label = [] + + x_train = data['train_data'] + x_label = data['train_label'] + val_data = data['test_data'] + val_label = data['test_label'] + + x_train = np.array(x_train, dtype=float) + val_data = np.array(val_data, dtype=float) + x_label = np.array(x_label, dtype=int) + val_label = np.array(val_label, dtype=int) + x_train = torch.FloatTensor(x_train) + val_data = torch.FloatTensor(val_data) + x_label = torch.LongTensor(x_label) + val_label = torch.LongTensor(val_label) + + return x_train, x_label, val_data, val_label + + +class CNNnet(torch.nn.Module): + def __init__(self, node_number, batch_size, k_hop): + super(CNNnet,self).__init__() + self.node_number = node_number + self.batch_size = batch_size + self.k_hop = k_hop + self.aggregate_weight = torch.nn.Parameter(torch.rand(1, 1, node_number)) + self.conv1 = torch.nn.Sequential( + torch.nn.Conv1d(in_channels=1, + out_channels=8, + kernel_size=3, + stride=1, + padding=1), + torch.nn.BatchNorm1d(8), + torch.nn.ReLU(), + torch.nn.MaxPool1d(kernel_size=2), + #torch.nn.AvgPool1d(kernel_size=2), + torch.nn.Dropout(0.2), + ) + self.conv2 = torch.nn.Sequential( + torch.nn.Conv1d(8,16,3,1,1), + torch.nn.BatchNorm1d(16), + torch.nn.ReLU(), + torch.nn.MaxPool1d(kernel_size=2), + #torch.nn.AvgPool1d(kernel_size=2), + torch.nn.Dropout(0.2), + ) + self.mlp1 = torch.nn.Sequential( + torch.nn.Linear(64*16,50), + torch.nn.Dropout(0.5), + ) + self.mlp2 = torch.nn.Linear(50,2) + def forward(self, x): + tmp_x = x + for _ in range(self.k_hop): + tmp_x = torch.matmul(tmp_x, x) + x = torch.matmul(self.aggregate_weight, tmp_x) + x = self.conv1(x) + x = self.conv2(x) + x = self.mlp1(x.view(x.size(0),-1)) + x = self.mlp2(x) + return x + +def main(): + + parser = argparse.ArgumentParser(description='PyTorch graph convolutional neural net for whole-graph classification') + parser.add_argument('--dataset', type=str, default="dataset/AEF_V_0.mat", help='path of the dataset (default: data/data.mat)') + parser.add_argument('--node_number', type=int, default=256, help='node number of graph (default: 256)') + parser.add_argument('--batch_size', type=int, default=32, help='number of input size (default: 128)') + parser.add_argument('--k_hop', type=int, default=4, help='times of aggregate (default: 1)') + + args = parser.parse_args() + + x_train, x_label, val_data, val_label = readfile(args.dataset) # 'train.csv' + x_train = x_train.permute(2, 0, 1) + x_label = torch.squeeze(x_label, dim=1).long() + + val_data = val_data.permute(2, 0, 1) + val_label = torch.squeeze(val_label, dim=1).long() + + train_set = TensorDataset(x_train, x_label) + val_set = TensorDataset(val_data, val_label) + + #batch_size = 128 + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=0) + val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=0) + + model = CNNnet(args.node_number, args.batch_size, args.k_hop) + #print(model) + model + loss = torch.nn.CrossEntropyLoss() + #para = list(model.parameters()) + #print(para) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # optimize all cnn parameters + loss_func = torch.nn.CrossEntropyLoss() + best_acc = 0.0 + + num_epoch = 100 + for epoch in range(num_epoch): + epoch_start_time = time.time() + train_acc = 0.0 + train_loss = 0.0 + val_acc = 0.0 + val_loss = 0.0 + + model.train() + for i, data in enumerate(train_loader): + optimizer.zero_grad() + + train_pred = model(data[0]) + #print(train_pred.size()) + #print(data[1].size()) + batch_loss = loss(train_pred, data[1]) + batch_loss.backward() + optimizer.step() + + train_acc += np.sum(np.argmax(train_pred.cpu().data.numpy(), axis=1) == data[1].numpy()) + train_loss += batch_loss.item() + + + + model.eval() + + val_TP = 1.0 + val_TN = 1.0 + val_FN = 1.0 + val_FP = 1.0 + + predict_total = [] + label_total = [] + + for i, data in enumerate(val_loader): + val_pred = model(data[0]) + batch_loss = loss(val_pred, data[1]) + + predict_val = np.argmax(val_pred.cpu().data.numpy(), axis=1) + predict_total = np.append(predict_total, predict_val) + label_val = data[1].numpy() + label_total = np.append(label_total, label_val) + + val_acc += np.sum(np.argmax(val_pred.cpu().data.numpy(), axis=1) == data[1].numpy()) + val_loss += batch_loss.item() + + + + val_TP = ((predict_total == 1) & (label_total == 1)).sum().item() + val_TN = ((predict_total == 0) & (label_total == 0)).sum().item() + val_FN = ((predict_total == 0) & (label_total == 1)).sum().item() + val_FP = ((predict_total == 1) & (label_total == 0)).sum().item() + + val_spe = val_TN/(val_FP + val_TN + 0.001) + val_rec = val_TP/(val_TP + val_FN + 0.001) + test_acc = (val_TP+val_TN)/(val_FP + val_TN + val_TP + val_FN + 0.001) + + val_acc = val_acc / val_set.__len__() + print('%3.6f %3.6f %3.6f %3.6f' % (train_acc / train_set.__len__(), train_loss, val_acc, val_loss)) + + if (val_acc > best_acc): + with open('save/AET_V_0.txt', 'w') as f: + f.write(str(epoch) + '\t' + str(val_acc) + '\t' + str(val_spe) + '\t' + str(val_rec) + '\n') + torch.save(model.state_dict(), 'save/model.pth') + best_acc = val_acc + + for name, param in model.named_parameters(): + if param.requires_grad: + print(param[0]) + +if __name__ == '__main__': + main()