[b52eda]: / Network.py

Download this file

61 lines (54 with data), 2.6 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from torch.nn import Module
from torch_geometric.nn import GATv2Conv
class CHD_GNN(Module):
    """
    PyTorch Geometric GNN used for coronary-CT segmentation
    of images with visible CHDs.

    A stack of nine ``GATv2Conv`` layers that widens the per-node feature
    size 1 -> 2 -> 4 -> 8, applying ``tanh`` after each widening layer.
    """

    def __init__(self, *args, dropout: float = 0.25, **kwargs):
        """
        Build the GAT layer stack.

        Arguments:
            *args, **kwargs: Forwarded unchanged to ``torch.nn.Module``.
            dropout (float): Value passed as ``GATv2Conv``'s ``dropout``
                argument to the 1->1, 2->2 and 4->4 layers. The widening
                layers and the final 8->8 layer use no dropout.
                Defaults to 0.25.
        """
        super().__init__(*args, **kwargs)
        # GAT layers. The attribute names become the parameter/state_dict
        # key prefixes, so they must stay stable for saved checkpoints.
        # NOTE(review): fill_value controls how self-loop edge features are
        # filled; forward() never passes edge attributes, so it presumably
        # has no effect here -- confirm before relying on it.
        self.gat_1_to_1 = GATv2Conv(1, 1,
                                    fill_value='sum',
                                    dropout=dropout)
        self.gat_1_to_2 = GATv2Conv(1, 2,
                                    fill_value='sum')
        self.gat_2_to_2_n1 = GATv2Conv(2, 2,
                                       fill_value='sum',
                                       dropout=dropout)
        self.gat_2_to_2_n2 = GATv2Conv(2, 2,
                                       fill_value='sum',
                                       dropout=dropout)
        self.gat_2_to_4 = GATv2Conv(2, 4,
                                    fill_value='sum')
        self.gat_4_to_4_n1 = GATv2Conv(4, 4,
                                       fill_value='sum',
                                       dropout=dropout)
        self.gat_4_to_4_n2 = GATv2Conv(4, 4,
                                       fill_value='sum',
                                       dropout=dropout)
        self.gat_4_to_8 = GATv2Conv(4, 8,
                                    fill_value='sum')
        self.gat_8_to_8 = GATv2Conv(8, 8,
                                    fill_value='sum')

    def forward(self, x, adj_matrix):
        """
        Run the node features through the GAT stack.

        Arguments:
            x (Tensor): Source coronary-CT image as a graph (node features).
                The first layer is ``GATv2Conv(1, 1, ...)``, so one input
                feature per node is expected.
            adj_matrix (Tensor): Connectivity of the ``x`` graph, passed to
                every layer as its ``edge_index`` argument.
                NOTE(review): despite the name, this is presumably a
                (2, num_edges) COO index tensor or a sparse adjacency rather
                than a dense matrix -- confirm against the caller.
        Returns:
            out (Tensor): Segmentation result as a graph, with the last
                layer being ``GATv2Conv(8, 8, ...)``.
        """
        # NOTE(review): tanh follows only the widening layers (1->2, 2->4,
        # 4->8). The 1->1 layer feeds 1->2 directly, and each 2->2 / 4->4
        # pair feeds the next layer directly with no activation in between.
        # Confirm this is intended.
        out = self.gat_1_to_1(x=x, edge_index=adj_matrix)
        out = self.gat_1_to_2(x=out, edge_index=adj_matrix)
        out = out.tanh()
        out = self.gat_2_to_2_n1(x=out, edge_index=adj_matrix)
        out = self.gat_2_to_2_n2(x=out, edge_index=adj_matrix)
        out = self.gat_2_to_4(x=out, edge_index=adj_matrix)
        out = out.tanh()
        out = self.gat_4_to_4_n1(x=out, edge_index=adj_matrix)
        out = self.gat_4_to_4_n2(x=out, edge_index=adj_matrix)
        out = self.gat_4_to_8(x=out, edge_index=adj_matrix)
        out = out.tanh()
        out = self.gat_8_to_8(x=out, edge_index=adj_matrix)
        return out