models.py
|
""" Components of the model
"""
|
import torch
import torch.nn as nn
import torch.nn.functional as F
|

def xavier_init(m):
    """Xavier (Glorot) normal initialization for Linear layers."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
|

class GraphConvolution(nn.Module):
    """One graph convolution layer: adj @ (x @ weight) (+ bias)."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        else:
            # Register the attribute explicitly so the `self.bias is not None`
            # checks below and in forward() are safe when bias=False.
            self.register_parameter('bias', None)
        nn.init.xavier_normal_(self.weight.data)
        if self.bias is not None:
            self.bias.data.fill_(0.0)
|
|

    def forward(self, x, adj):
        support = torch.mm(x, self.weight)      # x @ W
        output = torch.sparse.mm(adj, support)  # adj is a sparse tensor
        if self.bias is not None:
            return output + self.bias
        else:
            return output
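
# Illustrative usage sketch (not part of the original file): the layer expects
# `adj` to be an already-normalized torch sparse adjacency matrix, e.g.:
#
#   idx = torch.tensor([[0, 1, 2], [1, 2, 0]])               # COO edge indices
#   adj = torch.sparse_coo_tensor(idx, torch.ones(3), (3, 3))
#   layer = GraphConvolution(in_features=8, out_features=4)
#   out = layer(torch.randn(3, 8), adj)                      # shape (3, 4)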
|
|

class GCN_E(nn.Module):
    """Three-layer GCN encoder; hgcn_dim gives the three hidden sizes."""

    def __init__(self, in_dim, hgcn_dim, dropout):
        super().__init__()
        self.gc1 = GraphConvolution(in_dim, hgcn_dim[0])
        self.gc2 = GraphConvolution(hgcn_dim[0], hgcn_dim[1])
        self.gc3 = GraphConvolution(hgcn_dim[1], hgcn_dim[2])
        self.dropout = dropout

    def forward(self, x, adj):
        x = self.gc1(x, adj)
        x = F.leaky_relu(x, 0.25)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        x = F.leaky_relu(x, 0.25)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc3(x, adj)
        x = F.leaky_relu(x, 0.25)
        return x
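
# Illustrative sketch (not part of the original file): `hgcn_dim` must supply
# three hidden sizes, one per GraphConvolution layer, e.g.:
#
#   encoder = GCN_E(in_dim=1000, hgcn_dim=[400, 400, 200], dropout=0.5)
#   # maps (num_nodes, 1000) node features to (num_nodes, 200) embeddings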
|

class Classifier_1(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.clf = nn.Sequential(nn.Linear(in_dim, out_dim))
        self.clf.apply(xavier_init)

    def forward(self, x):
        x = self.clf(x)
        return x
|

class VCDN(nn.Module):
    """View Correlation Discovery Network: fuses per-view class scores."""

    def __init__(self, num_view, num_cls, hvcdn_dim):
        super().__init__()
        self.num_cls = num_cls
        self.model = nn.Sequential(
            nn.Linear(pow(num_cls, num_view), hvcdn_dim),
            nn.LeakyReLU(0.25),
            nn.Linear(hvcdn_dim, num_cls)
        )
        self.model.apply(xavier_init)

    def forward(self, in_list):
        num_view = len(in_list)
        # Squash each view's logits to (0, 1). Note: mutates in_list in place.
        for i in range(num_view):
            in_list[i] = torch.sigmoid(in_list[i])
        # Outer product of the first two views' scores:
        # (N, c, 1) @ (N, 1, c) -> (N, c, c), flattened to (N, c^2, 1).
        x = torch.reshape(
            torch.matmul(in_list[0].unsqueeze(-1), in_list[1].unsqueeze(1)),
            (-1, pow(self.num_cls, 2), 1))
        # Fold in each remaining view, growing the tensor to (N, c^(i+1), 1).
        for i in range(2, num_view):
            x = torch.reshape(
                torch.matmul(x, in_list[i].unsqueeze(1)),
                (-1, pow(self.num_cls, i + 1), 1))
        vcdn_feat = torch.reshape(x, (-1, pow(self.num_cls, num_view)))
        output = self.model(vcdn_feat)
        return output
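
# Shape check (illustrative): with num_view = 3 views and num_cls = 4 classes,
# the fused cross-view feature has pow(4, 3) = 64 entries per sample, which
# matches the in_features of the first Linear layer in self.model.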
|

def init_model_dict(num_view, num_class, dim_list, dim_he_list, dim_hc, gcn_dropout=0.5):
    model_dict = {}
    for i in range(num_view):
        model_dict["E{:}".format(i+1)] = GCN_E(dim_list[i], dim_he_list, gcn_dropout)
        model_dict["C{:}".format(i+1)] = Classifier_1(dim_he_list[-1], num_class)
    if num_view >= 2:
        # The cross-view fusion module is only needed when there are multiple views.
        model_dict["C"] = VCDN(num_view, num_class, dim_hc)
    return model_dict
|

def init_optim(num_view, model_dict, lr_e=1e-4, lr_c=1e-4):
    optim_dict = {}
    for i in range(num_view):
        # One optimizer per view, updating the encoder and its classifier jointly.
        optim_dict["C{:}".format(i+1)] = torch.optim.Adam(
            list(model_dict["E{:}".format(i+1)].parameters())
            + list(model_dict["C{:}".format(i+1)].parameters()),
            lr=lr_e)
    if num_view >= 2:
        optim_dict["C"] = torch.optim.Adam(model_dict["C"].parameters(), lr=lr_c)
    return optim_dict
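

if __name__ == "__main__":
    # Minimal smoke test (illustrative; all dimensions are made up). Wires up
    # the models and optimizers for three views and runs one forward pass per
    # view, plus VCDN fusion, using identity matrices as stand-in adjacencies.
    num_view, num_class, n = 3, 4, 10
    dim_list = [100, 200, 300]            # per-view input feature sizes
    dim_he_list = [64, 64, 32]            # GCN_E hidden sizes
    model_dict = init_model_dict(num_view, num_class, dim_list, dim_he_list, dim_hc=32)
    optim_dict = init_optim(num_view, model_dict)
    adj = torch.eye(n).to_sparse()
    scores = []
    for i in range(num_view):
        emb = model_dict["E{:}".format(i + 1)](torch.randn(n, dim_list[i]), adj)
        scores.append(model_dict["C{:}".format(i + 1)](emb))
    fused = model_dict["C"](scores)
    print(fused.shape)                    # expected: torch.Size([10, 4])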