Diff of /Network.py [000000] .. [b52eda]

Switch to side-by-side view

--- a
+++ b/Network.py
@@ -0,0 +1,61 @@
+from torch.nn import Module
+from torch_geometric.nn import GATv2Conv
+
+class CHD_GNN(Module):
+    r"""
+        PyTorch Geometric GNN used for coronary-CT segmentation \
+        of images with visible CHDs.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # GAT layers
+        self.gat_1_to_1 =       GATv2Conv(1, 1,
+                                          fill_value = 'sum',
+                                          dropout = 0.25)
+        self.gat_1_to_2 =       GATv2Conv(1, 2,
+                                          fill_value = 'sum')
+        self.gat_2_to_2_n1 =    GATv2Conv(2, 2,
+                                          fill_value = 'sum',
+                                          dropout = 0.25)
+        self.gat_2_to_2_n2 =    GATv2Conv(2, 2,
+                                          fill_value = 'sum',
+                                          dropout = 0.25)
+        self.gat_2_to_4 =       GATv2Conv(2, 4,
+                                          fill_value = 'sum')
+        self.gat_4_to_4_n1 =    GATv2Conv(4, 4,
+                                          fill_value = 'sum',
+                                          dropout = 0.25)
+        self.gat_4_to_4_n2 =    GATv2Conv(4, 4,
+                                          fill_value = 'sum',
+                                          dropout = 0.25)
+        self.gat_4_to_8 =       GATv2Conv(4, 8,
+                                          fill_value = 'sum')
+        self.gat_8_to_8 =       GATv2Conv(8, 8,
+                                          fill_value = 'sum')
+
+    def forward(self, x, adj_matrix):
+        r"""
+            Arguments:
+                x (Tensor): Source coronary-CT image as a graph.
+                adj_matrix (Tensor): Adjacency matrix of the x graph.
+            
+            Returns:
+                out (Tensor): Segmentation result as a graph.
+        """
+        out = self.gat_1_to_1(x = x, edge_index = adj_matrix)
+        out = self.gat_1_to_2(x = out, edge_index = adj_matrix)
+        out = out.tanh()
+
+        out = self.gat_2_to_2_n1(x = out, edge_index = adj_matrix)
+        out = self.gat_2_to_2_n2(x = out, edge_index = adj_matrix)
+        out = self.gat_2_to_4(x = out, edge_index = adj_matrix)
+        out = out.tanh()
+
+        out = self.gat_4_to_4_n1(x = out, edge_index = adj_matrix)
+        out = self.gat_4_to_4_n2(x = out, edge_index = adj_matrix)
+        out = self.gat_4_to_8(x = out, edge_index = adj_matrix)
+        out = out.tanh()
+
+        out = self.gat_8_to_8(x = out, edge_index = adj_matrix)
+        return out
\ No newline at end of file