Diff of /kgwas/model.py [000000] .. [8790ab]

Switch to unified view

a b/kgwas/model.py
1
from torch_geometric.nn import Linear, SAGEConv, GCNConv, SGConv, Sequential, to_hetero, HeteroConv
2
import torch
3
import torch.nn.functional as F
4
import torch.optim as optim
5
import torch.nn as nn
6
7
from .conv import GATConv
8
9
10
class SimpleMLP(nn.Module):
11
    def __init__(self, input_dim, hidden_dim, output_dim):
12
        super(SimpleMLP, self).__init__()
13
        self.FC_hidden = nn.Linear(input_dim, hidden_dim)
14
        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
15
        self.FC_output = nn.Linear(hidden_dim, output_dim)
16
        self.ReLU = nn.ReLU()       
17
                
18
    def forward(self, x):
19
        h     = self.ReLU(self.FC_hidden(x))
20
        h     = self.ReLU(self.FC_hidden2(h))
21
        x_hat = self.FC_output(h)
22
        return x_hat
23
24
class HeteroGNN(torch.nn.Module):
25
    def __init__(self, pyg_data, hidden_channels, out_channels, num_layers, gnn_backbone, gnn_aggr, snp_init_dim_size, gene_init_dim_size, go_init_dim_size, gat_num_head, no_relu = False):
26
        super().__init__()
27
        edge_types = pyg_data.edge_types
28
        self.convs = torch.nn.ModuleList()
29
        
30
        self.snp_feat_mlp = SimpleMLP(snp_init_dim_size, hidden_channels, hidden_channels)
31
        self.go_feat_mlp = SimpleMLP(go_init_dim_size, hidden_channels, hidden_channels)
32
        self.gene_feat_mlp = SimpleMLP(gene_init_dim_size, hidden_channels, hidden_channels)
33
        self.ReLU = nn.ReLU()   
34
        for _ in range(num_layers):
35
            conv_layer = {}
36
            for i in edge_types:
37
                if gnn_backbone == 'SAGE':
38
                    conv_layer[i] = SAGEConv((-1, -1), hidden_channels)
39
                elif gnn_backbone == 'GAT':
40
                    conv_layer[i] = GATConv((-1, -1), hidden_channels, 
41
                                            heads = gat_num_head, 
42
                                            add_self_loops = False)
43
                elif gnn_backbone == 'GCN':
44
                    conv_layer[i] = GCNConv(-1, hidden_channels, add_self_loops = False)
45
                elif gnn_backbone == 'SGC':
46
                    conv_layer[i] = SGConv(-1, hidden_channels, add_self_loops = False)
47
            conv = HeteroConv(conv_layer, aggr=gnn_aggr)
48
            self.convs.append(conv)
49
        
50
        self.lin = Linear(hidden_channels, out_channels)
51
        self.no_relu = no_relu
52
        
53
    def forward(self, x_dict, edge_index_dict, batch_size, genotype = None, return_h = False, 
54
                return_attention_weights = False):
55
        
56
        x_dict['SNP'] = self.snp_feat_mlp(x_dict['SNP'])
57
        x_dict['Gene'] = self.gene_feat_mlp(x_dict['Gene'])
58
        x_dict['CellularComponent'] = self.go_feat_mlp(x_dict['CellularComponent'])
59
        x_dict['BiologicalProcess'] = self.go_feat_mlp(x_dict['BiologicalProcess'])
60
        x_dict['MolecularFunction'] = self.go_feat_mlp(x_dict['MolecularFunction'])
61
        
62
        
63
        attention_all_layers = []
64
        for conv in self.convs:
65
            if return_attention_weights:
66
                out = conv(x_dict, edge_index_dict, 
67
                              return_attention_weights_dict = dict(zip(list(edge_index_dict.keys()), 
68
                                                            [True] * len(list(edge_index_dict.keys())))))
69
                #attention_layer = {i: [x[1] for x in j[1]] for i,j in out.items()}
70
                mean_attention = torch.mean(torch.vstack([torch.vstack([x[1] for x in j[1]]) for i,j in out.items()]))
71
                x_dict = {i: j[0] for i,j in out.items()}
72
                attention_all_layers.append(mean_attention)
73
            else:
74
                x_dict = conv(x_dict, edge_index_dict)
75
            x_dict = {key: x.relu() for key, x in x_dict.items()}
76
        
77
        
78
        if return_h:
79
            return self.ReLU(self.lin(x_dict['SNP']))[:batch_size], x_dict['SNP'][:batch_size] 
80
        if return_attention_weights:
81
            return self.ReLU(self.lin(x_dict['SNP']))[:batch_size], attention_all_layers
82
        else:
83
            if self.no_relu:
84
                return self.lin(x_dict['SNP'])[:batch_size]
85
            else:
86
                return self.ReLU(self.lin(x_dict['SNP']))[:batch_size]