Diff of /Network.py [000000] .. [b52eda]

Switch to unified view

a b/Network.py
1
from torch.nn import Module
2
from torch_geometric.nn import GATv2Conv
3
4
class CHD_GNN(Module):
5
    r"""
6
        PyTorch Geometric GNN used for coronary-CT segmentation \
7
        of images with visible CHDs.
8
    """
9
    def __init__(self, *args, **kwargs):
10
        super().__init__(*args, **kwargs)
11
12
        # GAT layers
13
        self.gat_1_to_1 =       GATv2Conv(1, 1,
14
                                          fill_value = 'sum',
15
                                          dropout = 0.25)
16
        self.gat_1_to_2 =       GATv2Conv(1, 2,
17
                                          fill_value = 'sum')
18
        self.gat_2_to_2_n1 =    GATv2Conv(2, 2,
19
                                          fill_value = 'sum',
20
                                          dropout = 0.25)
21
        self.gat_2_to_2_n2 =    GATv2Conv(2, 2,
22
                                          fill_value = 'sum',
23
                                          dropout = 0.25)
24
        self.gat_2_to_4 =       GATv2Conv(2, 4,
25
                                          fill_value = 'sum')
26
        self.gat_4_to_4_n1 =    GATv2Conv(4, 4,
27
                                          fill_value = 'sum',
28
                                          dropout = 0.25)
29
        self.gat_4_to_4_n2 =    GATv2Conv(4, 4,
30
                                          fill_value = 'sum',
31
                                          dropout = 0.25)
32
        self.gat_4_to_8 =       GATv2Conv(4, 8,
33
                                          fill_value = 'sum')
34
        self.gat_8_to_8 =       GATv2Conv(8, 8,
35
                                          fill_value = 'sum')
36
37
    def forward(self, x, adj_matrix):
38
        r"""
39
            Arguments:
40
                x (Tensor): Source coronary-CT image as a graph.
41
                adj_matrix (Tensor): Adjacency matrix of the x graph.
42
            
43
            Returns:
44
                out (Tensor): Segmentation result as a graph.
45
        """
46
        out = self.gat_1_to_1(x = x, edge_index = adj_matrix)
47
        out = self.gat_1_to_2(x = out, edge_index = adj_matrix)
48
        out = out.tanh()
49
50
        out = self.gat_2_to_2_n1(x = out, edge_index = adj_matrix)
51
        out = self.gat_2_to_2_n2(x = out, edge_index = adj_matrix)
52
        out = self.gat_2_to_4(x = out, edge_index = adj_matrix)
53
        out = out.tanh()
54
55
        out = self.gat_4_to_4_n1(x = out, edge_index = adj_matrix)
56
        out = self.gat_4_to_4_n2(x = out, edge_index = adj_matrix)
57
        out = self.gat_4_to_8(x = out, edge_index = adj_matrix)
58
        out = out.tanh()
59
60
        out = self.gat_8_to_8(x = out, edge_index = adj_matrix)
61
        return out