|
a |
|
b/gcn_model.py |
|
|
1 |
#!/usr/bin/env python |
|
|
2 |
# -*- coding: utf-8 -*- |
|
|
3 |
# @Time : 2021/8/8 16:20 |
|
|
4 |
# @Author : Li Xiao |
|
|
5 |
# @File : gcn_model.py |
|
|
6 |
from torch import nn |
|
|
7 |
import torch.nn.functional as F |
|
|
8 |
from layer import GraphConvolution |
|
|
9 |
|
|
|
10 |
class GCN(nn.Module): |
|
|
11 |
def __init__(self, n_in, n_hid, n_out, dropout=None): |
|
|
12 |
super(GCN, self).__init__() |
|
|
13 |
self.gc1 = GraphConvolution(n_in, n_hid) |
|
|
14 |
self.gc2 = GraphConvolution(n_hid, n_hid) |
|
|
15 |
self.dp1 = nn.Dropout(dropout) |
|
|
16 |
self.dp2 = nn.Dropout(dropout) |
|
|
17 |
#self.fc1 = nn.Linear(n_hid, n_hid) |
|
|
18 |
self.fc = nn.Linear(n_hid, n_out) |
|
|
19 |
self.dropout = dropout |
|
|
20 |
|
|
|
21 |
def forward(self, input, adj): |
|
|
22 |
x = self.gc1(input, adj) |
|
|
23 |
x = F.elu(x) |
|
|
24 |
x = self.dp1(x) |
|
|
25 |
x = self.gc2(x, adj) |
|
|
26 |
x = F.elu(x) |
|
|
27 |
x = self.dp2(x) |
|
|
28 |
|
|
|
29 |
x = self.fc(x) |
|
|
30 |
|
|
|
31 |
return x |