Diff of /models.py [000000] .. [ab33d2]

Switch to unified view

a b/models.py
1
""" Componets of the model
2
"""
3
import torch.nn as nn
4
import torch
5
import torch.nn.functional as F
6
7
8
def xavier_init(m):
9
    if type(m) == nn.Linear:
10
        nn.init.xavier_normal_(m.weight)
11
        if m.bias is not None:
12
           m.bias.data.fill_(0.0)
13
           
14
15
class GraphConvolution(nn.Module):
16
    def __init__(self, in_features, out_features, bias=True):
17
        super().__init__()
18
        self.in_features = in_features
19
        self.out_features = out_features
20
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
21
        if bias:
22
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
23
        nn.init.xavier_normal_(self.weight.data)
24
        if self.bias is not None:
25
            self.bias.data.fill_(0.0)
26
    
27
    def forward(self, x, adj):
28
        support = torch.mm(x, self.weight)
29
        output = torch.sparse.mm(adj, support)
30
        if self.bias is not None:
31
            return output + self.bias
32
        else:
33
            return output
34
    
35
36
class GCN_E(nn.Module):
37
    def __init__(self, in_dim, hgcn_dim, dropout):
38
        super().__init__()
39
        self.gc1 = GraphConvolution(in_dim, hgcn_dim[0])
40
        self.gc2 = GraphConvolution(hgcn_dim[0], hgcn_dim[1])
41
        self.gc3 = GraphConvolution(hgcn_dim[1], hgcn_dim[2])
42
        self.dropout = dropout
43
44
    def forward(self, x, adj):
45
        x = self.gc1(x, adj)
46
        x = F.leaky_relu(x, 0.25)
47
        x = F.dropout(x, self.dropout, training=self.training)
48
        x = self.gc2(x, adj)
49
        x = F.leaky_relu(x, 0.25)
50
        x = F.dropout(x, self.dropout, training=self.training)
51
        x = self.gc3(x, adj)
52
        x = F.leaky_relu(x, 0.25)
53
        
54
        return x
55
56
57
class Classifier_1(nn.Module):
58
    def __init__(self, in_dim, out_dim):
59
        super().__init__()
60
        self.clf = nn.Sequential(nn.Linear(in_dim, out_dim))
61
        self.clf.apply(xavier_init)
62
63
    def forward(self, x):
64
        x = self.clf(x)
65
        return x
66
67
68
class VCDN(nn.Module):
69
    def __init__(self, num_view, num_cls, hvcdn_dim):
70
        super().__init__()
71
        self.num_cls = num_cls
72
        self.model = nn.Sequential(
73
            nn.Linear(pow(num_cls, num_view), hvcdn_dim),
74
            nn.LeakyReLU(0.25),
75
            nn.Linear(hvcdn_dim, num_cls)
76
        )
77
        self.model.apply(xavier_init)
78
        
79
    def forward(self, in_list):
80
        num_view = len(in_list)
81
        for i in range(num_view):
82
            in_list[i] = torch.sigmoid(in_list[i])
83
        x = torch.reshape(torch.matmul(in_list[0].unsqueeze(-1), in_list[1].unsqueeze(1)),(-1,pow(self.num_cls,2),1))
84
        for i in range(2,num_view):
85
            x = torch.reshape(torch.matmul(x, in_list[i].unsqueeze(1)),(-1,pow(self.num_cls,i+1),1))
86
        vcdn_feat = torch.reshape(x, (-1,pow(self.num_cls,num_view)))
87
        output = self.model(vcdn_feat)
88
89
        return output
90
91
    
92
def init_model_dict(num_view, num_class, dim_list, dim_he_list, dim_hc, gcn_dopout=0.5):
93
    model_dict = {}
94
    for i in range(num_view):
95
        model_dict["E{:}".format(i+1)] = GCN_E(dim_list[i], dim_he_list, gcn_dopout)
96
        model_dict["C{:}".format(i+1)] = Classifier_1(dim_he_list[-1], num_class)
97
    if num_view >= 2:
98
        model_dict["C"] = VCDN(num_view, num_class, dim_hc)
99
    return model_dict
100
101
102
def init_optim(num_view, model_dict, lr_e=1e-4, lr_c=1e-4):
103
    optim_dict = {}
104
    for i in range(num_view):
105
        optim_dict["C{:}".format(i+1)] = torch.optim.Adam(
106
                list(model_dict["E{:}".format(i+1)].parameters())+list(model_dict["C{:}".format(i+1)].parameters()), 
107
                lr=lr_e)
108
    if num_view >= 2:
109
        optim_dict["C"] = torch.optim.Adam(model_dict["C"].parameters(), lr=lr_c)
110
    return optim_dict