kgwas/model.py

from torch_geometric.nn import Linear, SAGEConv, GCNConv, SGConv, HeteroConv
import torch
import torch.nn as nn

from .conv import GATConv
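# NOTE: GATConv is the repository's local variant (kgwas/conv.py) rather than
# torch_geometric.nn.GATConv; it is presumably adapted to keep hidden sizes
# consistent across attention heads and to expose attention weights through
# HeteroConv (see HeteroGNN.forward below).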
|
|
class SimpleMLP(nn.Module):
    """Two-hidden-layer MLP with ReLU activations, used to project raw node
    features into the model's shared hidden dimension."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.FC_hidden = nn.Linear(input_dim, hidden_dim)
        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)
        self.ReLU = nn.ReLU()

    def forward(self, x):
        h = self.ReLU(self.FC_hidden(x))
        h = self.ReLU(self.FC_hidden2(h))
        return self.FC_output(h)
|
|
class HeteroGNN(torch.nn.Module):
    """Heterogeneous GNN over the KGWAS SNP/Gene/GO knowledge graph.

    Raw features of each node type are projected into `hidden_channels` by
    per-type MLPs, refined by `num_layers` rounds of per-edge-type message
    passing, and finally mapped to `out_channels` scores for the SNP nodes.
    """

    def __init__(self, pyg_data, hidden_channels, out_channels, num_layers,
                 gnn_backbone, gnn_aggr, snp_init_dim_size, gene_init_dim_size,
                 go_init_dim_size, gat_num_head, no_relu=False):
        super().__init__()
        edge_types = pyg_data.edge_types
        self.convs = torch.nn.ModuleList()

        # Input projections; the three Gene Ontology node types share one MLP.
        self.snp_feat_mlp = SimpleMLP(snp_init_dim_size, hidden_channels, hidden_channels)
        self.go_feat_mlp = SimpleMLP(go_init_dim_size, hidden_channels, hidden_channels)
        self.gene_feat_mlp = SimpleMLP(gene_init_dim_size, hidden_channels, hidden_channels)
        self.ReLU = nn.ReLU()
|
|
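        # One HeteroConv per layer: a dedicated convolution for every edge
        # type, aggregated across types with `gnn_aggr`. Lazy input sizes
        # ((-1, -1) / -1) are inferred on the first forward pass.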
        for _ in range(num_layers):
            conv_layer = {}
            for edge_type in edge_types:
                if gnn_backbone == 'SAGE':
                    conv_layer[edge_type] = SAGEConv((-1, -1), hidden_channels)
                elif gnn_backbone == 'GAT':
                    conv_layer[edge_type] = GATConv((-1, -1), hidden_channels,
                                                    heads=gat_num_head,
                                                    add_self_loops=False)
                elif gnn_backbone == 'GCN':
                    conv_layer[edge_type] = GCNConv(-1, hidden_channels, add_self_loops=False)
                elif gnn_backbone == 'SGC':
                    conv_layer[edge_type] = SGConv(-1, hidden_channels, add_self_loops=False)
                else:
                    raise ValueError(f'Unsupported gnn_backbone: {gnn_backbone}')
            conv = HeteroConv(conv_layer, aggr=gnn_aggr)
            self.convs.append(conv)
|
|
        self.lin = Linear(hidden_channels, out_channels)
        self.no_relu = no_relu
|
|
    def forward(self, x_dict, edge_index_dict, batch_size, genotype=None,
                return_h=False, return_attention_weights=False):
        # `genotype` is currently unused in this forward pass.
        # Project every node type's raw features into the shared hidden space.
        x_dict['SNP'] = self.snp_feat_mlp(x_dict['SNP'])
        x_dict['Gene'] = self.gene_feat_mlp(x_dict['Gene'])
        x_dict['CellularComponent'] = self.go_feat_mlp(x_dict['CellularComponent'])
        x_dict['BiologicalProcess'] = self.go_feat_mlp(x_dict['BiologicalProcess'])
        x_dict['MolecularFunction'] = self.go_feat_mlp(x_dict['MolecularFunction'])
|
|
        # One scalar attention summary per layer, filled only when requested.
        attention_all_layers = []
        for conv in self.convs:
            if return_attention_weights:
                # HeteroConv distributes kwargs suffixed with `_dict` to the
                # convolution of each edge type, so every conv is also asked
                # to return its attention weights.
                out = conv(x_dict, edge_index_dict,
                           return_attention_weights_dict={edge_type: True
                                                          for edge_type in edge_index_dict})
                # Scalar mean over the attention coefficients of all edge types.
                mean_attention = torch.mean(torch.vstack(
                    [torch.vstack([x[1] for x in j[1]]) for i, j in out.items()]))
                x_dict = {i: j[0] for i, j in out.items()}
                attention_all_layers.append(mean_attention)
            else:
                x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
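
        # Under PyG's NeighborLoader convention, the first `batch_size` 'SNP'
        # rows are the seed nodes of the mini-batch, so all outputs below are
        # sliced down to those seeds.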
        if return_h:
            return self.ReLU(self.lin(x_dict['SNP']))[:batch_size], x_dict['SNP'][:batch_size]
        if return_attention_weights:
            return self.ReLU(self.lin(x_dict['SNP']))[:batch_size], attention_all_layers
        if self.no_relu:
            return self.lin(x_dict['SNP'])[:batch_size]
        return self.ReLU(self.lin(x_dict['SNP']))[:batch_size]
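

# Minimal usage sketch (illustrative only; the hyperparameter values are
# assumptions, and `data`/`batch` stand for the HeteroData object and a
# NeighborLoader mini-batch with 'SNP' seed nodes):
#
#     model = HeteroGNN(data, hidden_channels=128, out_channels=1, num_layers=2,
#                       gnn_backbone='GAT', gnn_aggr='sum',
#                       snp_init_dim_size=data['SNP'].x.size(1),
#                       gene_init_dim_size=data['Gene'].x.size(1),
#                       go_init_dim_size=data['CellularComponent'].x.size(1),
#                       gat_num_head=2)
#     pred = model(batch.x_dict, batch.edge_index_dict, batch['SNP'].batch_size)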