# kgwas/conv.py

from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from torch_sparse import SparseTensor, set_diag

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import NoneType  # noqa
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros

# Unused helper, kept commented out:
'''
def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
    if len(xs) == 0:
        return None
    elif aggr is None:
        return torch.stack(xs, dim=1)
    elif len(xs) == 1:
        return xs[0]
    elif isinstance(xs[0], tuple):
        return xs
    else:
        out = torch.stack(xs, dim=0)
        out = getattr(torch, aggr)(out, dim=0)
        out = out[0] if isinstance(out, tuple) else out
        return out
'''


class GATConv(MessagePassing):
    """Graph attention layer (GAT) with an optional sigmoid gate in place of
    the attention softmax (``sigmoid_gat``), a temperature applied to the
    attention logits (``temperature``), and optional phenotype conditioning
    (``pheno_condition``, one of ``False``, ``'ATT'`` or ``'MSG'``)."""

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        add_self_loops: bool = True,
        edge_dim: Optional[int] = None,
        fill_value: Union[float, Tensor, str] = 'mean',
        bias: bool = True,
        sigmoid_gat: bool = False,
        temperature: float = 1.0,
        pheno_condition: Union[bool, str] = False,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.edge_dim = edge_dim
        self.fill_value = fill_value
        self.sigmoid_gat = sigmoid_gat
        self.temperature = temperature
        self.pheno_condition = pheno_condition

        # Optional phenotype-conditioning sub-modules ('ATT' adds an extra
        # edge projection and attention vector; 'MSG' adds a phenotype MLP):
        if self.pheno_condition == 'ATT':
            self.lin_edge_ = Linear(self.out_channels, heads * out_channels,
                                    bias=False, weight_initializer='glorot')
            self.att_edge = Parameter(torch.Tensor(1, heads, out_channels))

        elif self.pheno_condition == 'MSG':
            self.pheno_mlp = Linear(edge_dim, heads * out_channels, bias=False,
                                    weight_initializer='glorot')

        # In case we are operating in bipartite graphs, we apply separate
        # transformations 'lin_src' and 'lin_dst' to source and target nodes:
        if isinstance(in_channels, int):
            self.lin_src = Linear(in_channels, heads * out_channels,
                                  bias=False, weight_initializer='glorot')
            self.lin_dst = self.lin_src
        else:
            self.lin_src = Linear(in_channels[0], heads * out_channels, False,
                                  weight_initializer='glorot')
            self.lin_dst = Linear(in_channels[1], heads * out_channels, False,
                                  weight_initializer='glorot')

        # The learnable parameters to compute attention coefficients:
        self.att_src = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_dst = Parameter(torch.Tensor(1, heads, out_channels))

        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False,
                                   weight_initializer='glorot')
            self.att_edge = Parameter(torch.Tensor(1, heads, out_channels))
        else:
            self.lin_edge = None
            self.register_parameter('att_edge', None)

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_src.reset_parameters()
        self.lin_dst.reset_parameters()
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        glorot(self.att_src)
        glorot(self.att_dst)
        glorot(self.att_edge)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None,
                pheno_emb=None, return_attention_weights=None,
                return_raw_attention_weights=None):
        # `pheno_emb` is accepted for interface compatibility but is not used
        # inside this layer.

        # If raw attention weights are requested, `edge_update` skips the
        # softmax normalization:
        self.return_raw_attention_weights = bool(return_raw_attention_weights)

        H, C = self.heads, self.out_channels

        # We first transform the input node features. If a tuple is passed, we
        # transform source and target node features via separate weights:
        if isinstance(x, Tensor):
            assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
            x_src = x_dst = self.lin_src(x).view(-1, H, C)
        else:  # Tuple of source and target node features:
            x_src, x_dst = x
            assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
            x_src = self.lin_src(x_src).view(-1, H, C)
            if x_dst is not None:
                x_dst = self.lin_dst(x_dst).view(-1, H, C)

        x = (x_src, x_dst)

        # Next, we compute node-level attention coefficients, both for source
        # and target nodes (if present):
        alpha_src = (x_src * self.att_src).sum(dim=-1)
        alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
        alpha = (alpha_src, alpha_dst)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # We only want to add self-loops for nodes that appear both as
                # source and target nodes:
                num_nodes = x_src.size(0)
                if x_dst is not None:
                    num_nodes = min(num_nodes, x_dst.size(0))
                num_nodes = min(size) if size is not None else num_nodes
                edge_index, edge_attr = remove_self_loops(
                    edge_index, edge_attr)
                edge_index, edge_attr = add_self_loops(
                    edge_index, edge_attr, fill_value=self.fill_value,
                    num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                if self.edge_dim is None:
                    edge_index = set_diag(edge_index)
                else:
                    raise NotImplementedError(
                        "The usage of 'edge_attr' and 'add_self_loops' "
                        "simultaneously is currently not yet supported for "
                        "'edge_index' in a 'SparseTensor' form")

        # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor)
        alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr)

        # if self.return_raw_attention_weights:
        #     return 0, (edge_index, alpha)
        # propagate_type: (x: OptPairTensor, alpha: Tensor)
        out = self.propagate(edge_index, x=x, alpha=alpha, size=size)

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor,
                    edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                    size_i: Optional[int]) -> Tensor:
        # Given edge-level attention coefficients for source and target nodes,
        # we simply need to sum them up to "emulate" concatenation:
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

        if edge_attr is not None and self.lin_edge is not None:
            if edge_attr.dim() == 1:
                edge_attr = edge_attr.view(-1, 1)
            edge_attr = self.lin_edge(edge_attr)
            edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
            alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
            alpha = alpha + alpha_edge

        alpha = F.leaky_relu(alpha, self.negative_slope)

        # Temperature-scaled attention: either an elementwise sigmoid gate or
        # the standard per-target softmax (skipped when raw logits are
        # requested via `return_raw_attention_weights`):
        if self.sigmoid_gat:
            alpha = torch.sigmoid(alpha / self.temperature)
        else:
            if not self.return_raw_attention_weights:
                alpha = softmax(alpha / self.temperature, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return alpha

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        return alpha.unsqueeze(-1) * x_j

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
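

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): it builds a small random graph to
# show the expected input/output shapes of this layer. The toy sizes, the
# 4-dimensional edge attributes, and the `sigmoid_gat`/`temperature` settings
# below are arbitrary assumptions, not values used by the surrounding package.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_nodes, in_dim, out_dim, heads, edge_dim = 6, 16, 8, 2, 4
    x = torch.randn(num_nodes, in_dim)
    # A directed ring over the six nodes:
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 3, 4, 5, 0]])
    edge_attr = torch.randn(edge_index.size(1), edge_dim)

    conv = GATConv(in_dim, out_dim, heads=heads, edge_dim=edge_dim,
                   sigmoid_gat=False, temperature=1.0)
    out, (ei, alpha) = conv(x, edge_index, edge_attr=edge_attr,
                            return_attention_weights=True)
    print(out.shape)    # [num_nodes, heads * out_dim] since concat=True
    print(alpha.shape)  # [num_edges + num_nodes self-loops, heads]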