|
a |
|
b/layer.py |
|
|
1 |
#!/usr/bin/env python |
|
|
2 |
# -*- coding: utf-8 -*- |
|
|
3 |
# @Time : 2021/8/8 16:19 |
|
|
4 |
# @Author : Li Xiao |
|
|
5 |
# @File : layer.py |
|
|
6 |
import torch |
|
|
7 |
import math |
|
|
8 |
from torch import nn |
|
|
9 |
from torch.nn.parameter import Parameter |
|
|
10 |
|
|
|
11 |
class GraphConvolution(nn.Module): |
|
|
12 |
def __init__(self, infeas, outfeas, bias=True): |
|
|
13 |
super(GraphConvolution,self).__init__() |
|
|
14 |
self.in_features = infeas |
|
|
15 |
self.out_features = outfeas |
|
|
16 |
self.weight = Parameter(torch.FloatTensor(infeas, outfeas)) |
|
|
17 |
if bias: |
|
|
18 |
self.bias = Parameter(torch.FloatTensor(outfeas)) |
|
|
19 |
else: |
|
|
20 |
self.register_parameter('bias', None) |
|
|
21 |
self.reset_parameters() |
|
|
22 |
|
|
|
23 |
def reset_parameters(self): |
|
|
24 |
|
|
|
25 |
stdv = 1. / math.sqrt(self.weight.size(1)) |
|
|
26 |
self.weight.data.uniform_(-stdv, stdv) |
|
|
27 |
if self.bias is not None: |
|
|
28 |
self.bias.data.uniform_(-stdv,stdv) |
|
|
29 |
''' |
|
|
30 |
for name, param in GraphConvolution.named_parameters(self): |
|
|
31 |
if 'weight' in name: |
|
|
32 |
#torch.nn.init.constant_(param, val=0.1) |
|
|
33 |
torch.nn.init.normal_(param, mean=0, std=0.1) |
|
|
34 |
if 'bias' in name: |
|
|
35 |
torch.nn.init.constant_(param, val=0) |
|
|
36 |
''' |
|
|
37 |
|
|
|
38 |
def forward(self, x, adj): |
|
|
39 |
x1 = torch.mm(x, self.weight) |
|
|
40 |
output = torch.mm(adj, x1) |
|
|
41 |
if self.bias is not None: |
|
|
42 |
return output + self.bias |
|
|
43 |
else: |
|
|
44 |
return output |