--- /dev/null
+++ b/kgwas/model.py
@@ -0,0 +1,96 @@
+from torch_geometric.nn import Linear, SAGEConv, GCNConv, SGConv, HeteroConv
+import torch
+import torch.nn as nn
+
+from .conv import GATConv
+
+
+class SimpleMLP(nn.Module):
+    """Two-hidden-layer MLP used to project raw node features into the hidden space."""
+
+    def __init__(self, input_dim, hidden_dim, output_dim):
+        super().__init__()
+        self.FC_hidden = nn.Linear(input_dim, hidden_dim)
+        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
+        self.FC_output = nn.Linear(hidden_dim, output_dim)
+        self.ReLU = nn.ReLU()
+
+    def forward(self, x):
+        h = self.ReLU(self.FC_hidden(x))
+        h = self.ReLU(self.FC_hidden2(h))
+        return self.FC_output(h)
+
+
+class HeteroGNN(torch.nn.Module):
+    """Heterogeneous GNN over the SNP/Gene/GO graph with a configurable backbone."""
+
+    def __init__(self, pyg_data, hidden_channels, out_channels, num_layers,
+                 gnn_backbone, gnn_aggr, snp_init_dim_size, gene_init_dim_size,
+                 go_init_dim_size, gat_num_head, no_relu=False):
+        super().__init__()
+        edge_types = pyg_data.edge_types
+        self.convs = torch.nn.ModuleList()
+
+        # Per-node-type MLPs that map raw feature dimensions to hidden_channels.
+        self.snp_feat_mlp = SimpleMLP(snp_init_dim_size, hidden_channels, hidden_channels)
+        self.go_feat_mlp = SimpleMLP(go_init_dim_size, hidden_channels, hidden_channels)
+        self.gene_feat_mlp = SimpleMLP(gene_init_dim_size, hidden_channels, hidden_channels)
+        self.ReLU = nn.ReLU()
+
+        # One HeteroConv per layer, with a separate convolution for every edge type.
+        for _ in range(num_layers):
+            conv_layer = {}
+            for edge_type in edge_types:
+                if gnn_backbone == 'SAGE':
+                    conv_layer[edge_type] = SAGEConv((-1, -1), hidden_channels)
+                elif gnn_backbone == 'GAT':
+                    conv_layer[edge_type] = GATConv((-1, -1), hidden_channels,
+                                                    heads=gat_num_head,
+                                                    add_self_loops=False)
+                elif gnn_backbone == 'GCN':
+                    conv_layer[edge_type] = GCNConv(-1, hidden_channels, add_self_loops=False)
+                elif gnn_backbone == 'SGC':
+                    conv_layer[edge_type] = SGConv(-1, hidden_channels, add_self_loops=False)
+                else:
+                    raise ValueError(f'Unknown gnn_backbone: {gnn_backbone}')
+            self.convs.append(HeteroConv(conv_layer, aggr=gnn_aggr))
+
+        self.lin = Linear(hidden_channels, out_channels)
+        self.no_relu = no_relu
+
+    def forward(self, x_dict, edge_index_dict, batch_size, genotype=None,
+                return_h=False, return_attention_weights=False):
+        # genotype is accepted for API compatibility but is not used here.
+        # Project every node type into the shared hidden space.
+        x_dict['SNP'] = self.snp_feat_mlp(x_dict['SNP'])
+        x_dict['Gene'] = self.gene_feat_mlp(x_dict['Gene'])
+        x_dict['CellularComponent'] = self.go_feat_mlp(x_dict['CellularComponent'])
+        x_dict['BiologicalProcess'] = self.go_feat_mlp(x_dict['BiologicalProcess'])
+        x_dict['MolecularFunction'] = self.go_feat_mlp(x_dict['MolecularFunction'])
+
+        attention_all_layers = []
+        for conv in self.convs:
+            if return_attention_weights:
+                # Request per-edge-type attention weights (only supported by the GAT backbone).
+                out = conv(x_dict, edge_index_dict,
+                           return_attention_weights_dict={edge_type: True
+                                                          for edge_type in edge_index_dict})
+                # Average the attention weights across all edge types in this layer.
+                mean_attention = torch.mean(torch.vstack(
+                    [torch.vstack([x[1] for x in j[1]]) for i, j in out.items()]))
+                x_dict = {i: j[0] for i, j in out.items()}
+                attention_all_layers.append(mean_attention)
+            else:
+                x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: x.relu() for key, x in x_dict.items()}
+
+        # Read out SNP predictions for the first batch_size (seed) nodes.
+        pred = self.lin(x_dict['SNP'])
+
+        if return_h:
+            return self.ReLU(pred)[:batch_size], x_dict['SNP'][:batch_size]
+        if return_attention_weights:
+            return self.ReLU(pred)[:batch_size], attention_all_layers
+        if self.no_relu:
+            return pred[:batch_size]
+        return self.ReLU(pred)[:batch_size]
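
For reference, below is a minimal, self-contained sketch of how the HeteroGNN added above might be exercised on a toy heterogeneous graph, using only standard torch_geometric APIs plus the class from this diff. The node counts, feature sizes, edge wiring, and hyperparameter values are illustrative assumptions, not part of this change; the real kgwas pipeline builds its graph elsewhere.

import torch
from torch_geometric.data import HeteroData

from kgwas.model import HeteroGNN  # assumption: the package is importable as `kgwas`


def random_edges(num_src, num_dst, num_edges):
    # Random bipartite edge_index of shape [2, num_edges].
    return torch.stack([torch.randint(num_src, (num_edges,)),
                        torch.randint(num_dst, (num_edges,))])


num_nodes = {'SNP': 50, 'Gene': 20, 'BiologicalProcess': 10,
             'CellularComponent': 10, 'MolecularFunction': 10}
feat_dims = {'SNP': 64, 'Gene': 32, 'BiologicalProcess': 16,
             'CellularComponent': 16, 'MolecularFunction': 16}

data = HeteroData()
for node_type, n in num_nodes.items():
    data[node_type].x = torch.randn(n, feat_dims[node_type])

# Wire each node type so it both sends and receives messages (see note below).
for src, dst in [('SNP', 'Gene'), ('Gene', 'BiologicalProcess'),
                 ('Gene', 'CellularComponent'), ('Gene', 'MolecularFunction')]:
    data[src, 'to', dst].edge_index = random_edges(num_nodes[src], num_nodes[dst], 200)
    data[dst, 'rev_to', src].edge_index = random_edges(num_nodes[dst], num_nodes[src], 200)

model = HeteroGNN(data, hidden_channels=32, out_channels=1, num_layers=2,
                  gnn_backbone='SAGE', gnn_aggr='sum',
                  snp_init_dim_size=feat_dims['SNP'],
                  gene_init_dim_size=feat_dims['Gene'],
                  go_init_dim_size=16, gat_num_head=2)

pred = model(data.x_dict, data.edge_index_dict, batch_size=num_nodes['SNP'])
print(pred.shape)  # torch.Size([50, 1])

Note that every node type in the toy graph is wired to both send and receive messages: HeteroConv only returns embeddings for destination node types, so a type that never receives messages would drop out of x_dict after the first layer and break the SNP readout.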