# src/model/layers.py

import math

import torch
import torch.nn as nn
from torch.nn import functional as F

class MLP(nn.Module):
    """
    A simple Multi-Layer Perceptron (MLP) module consisting of two linear layers with a ReLU activation in between,
    followed by a dropout on the output.

    Attributes:
        fc1 (nn.Linear): The first fully-connected layer.
        act (nn.ReLU): ReLU activation function.
        fc2 (nn.Linear): The second fully-connected layer.
        droprateout (nn.Dropout): Dropout layer applied to the output.
    """
    def __init__(self, in_feat, hid_feat=None, out_feat=None, dropout=0.):
        """
        Initializes the MLP module.

        Args:
            in_feat (int): Number of input features.
            hid_feat (int, optional): Number of hidden features. Defaults to in_feat if not provided.
            out_feat (int, optional): Number of output features. Defaults to in_feat if not provided.
            dropout (float, optional): Dropout rate. Defaults to 0.
        """
        super().__init__()

        # Set hidden and output dimensions to the input dimension if not specified
        if not hid_feat:
            hid_feat = in_feat
        if not out_feat:
            out_feat = in_feat

        self.fc1 = nn.Linear(in_feat, hid_feat)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hid_feat, out_feat)
        self.droprateout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Forward pass for the MLP.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying the linear layers, activation, and dropout.
        """
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return self.droprateout(x)
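
# Hedged usage sketch added for illustration (not part of the original module):
# it shows how the MLP above can be applied to a batch of feature vectors.
# The sizes used here (8 samples, 64 features, hidden size 128) are arbitrary assumptions.
def _mlp_usage_sketch():
    mlp = MLP(in_feat=64, hid_feat=128, out_feat=64, dropout=0.1)
    x = torch.randn(8, 64)   # (batch, in_feat) with random placeholder values
    out = mlp(x)             # (batch, out_feat) == (8, 64)
    return out.shape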

class MHA(nn.Module):
    """
    Multi-Head Attention (MHA) module of the graph transformer, with edge features incorporated into the attention computation.

    Attributes:
        heads (int): Number of attention heads.
        scale (float): Scaling factor for the attention scores (stored but unused; the forward pass divides by sqrt(d_k) directly).
        q, k, v (nn.Linear): Linear layers to project the node features into query, key, and value embeddings.
        e (nn.Linear): Linear layer to project the edge features.
        d_k (int): Dimension of each attention head.
        out_e (nn.Linear): Linear layer applied to the computed edge features.
        out_n (nn.Linear): Linear layer applied to the aggregated node features.
    """
    def __init__(self, dim, heads, attention_dropout=0.):
        """
        Initializes the Multi-Head Attention module.

        Args:
            dim (int): Dimensionality of the input features.
            heads (int): Number of attention heads.
            attention_dropout (float, optional): Dropout rate for attention (not used explicitly in this implementation).
        """
        super().__init__()

        # Ensure that the dimension is divisible by the number of heads
        assert dim % heads == 0

        self.heads = heads
        self.scale = 1. / math.sqrt(dim)  # Scaling factor for attention (stored but unused; forward uses sqrt(d_k))
        # Linear layers for projecting node features
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        # Linear layer for projecting edge features
        self.e = nn.Linear(dim, dim)
        self.d_k = dim // heads  # Dimension per head

        # Linear layers for output transformations
        self.out_e = nn.Linear(dim, dim)
        self.out_n = nn.Linear(dim, dim)

    def forward(self, node, edge):
        """
        Forward pass for the Multi-Head Attention.

        Args:
            node (torch.Tensor): Node feature tensor of shape (batch, num_nodes, dim).
            edge (torch.Tensor): Edge feature tensor of shape (batch, num_nodes, num_nodes, dim).

        Returns:
            tuple: (updated node features, updated edge features)
        """
        b, n, c = node.shape

        # Compute query, key, and value embeddings and reshape for multi-head attention
        q_embed = self.q(node).view(b, n, self.heads, c // self.heads)
        k_embed = self.k(node).view(b, n, self.heads, c // self.heads)
        v_embed = self.v(node).view(b, n, self.heads, c // self.heads)

        # Compute edge embeddings
        e_embed = self.e(edge).view(b, n, n, self.heads, c // self.heads)

        # Adjust dimensions for broadcasting: add singleton dimensions to queries and keys
        q_embed = q_embed.unsqueeze(2)  # Shape: (b, n, 1, heads, c//heads)
        k_embed = k_embed.unsqueeze(1)  # Shape: (b, 1, n, heads, c//heads)

        # Compute attention scores
        attn = q_embed * k_embed
        attn = attn / math.sqrt(self.d_k)
        attn = attn * (e_embed + 1) * e_embed  # Modulate the attention scores with the edge features

        edge_out = self.out_e(attn.flatten(3))  # Merge the head and per-head dimensions back into dim for the linear layer

        # Apply softmax over the node dimension to obtain normalized attention weights
        attn = F.softmax(attn, dim=2)

        v_embed = v_embed.unsqueeze(1)  # Adjust dimensions to broadcast: (b, 1, n, heads, c//heads)
        v_embed = attn * v_embed
        v_embed = v_embed.sum(dim=2).flatten(2)
        node_out = self.out_n(v_embed)

        return node_out, edge_out
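
# Hedged shape walk-through added for illustration (not part of the original module):
# MHA maps node features of shape (b, n, dim) and edge features of shape
# (b, n, n, dim) to outputs of the same shapes. All sizes below are assumptions.
def _mha_shape_sketch():
    b, n, dim, heads = 2, 5, 32, 4
    mha = MHA(dim, heads)
    node = torch.randn(b, n, dim)       # node features
    edge = torch.randn(b, n, n, dim)    # pairwise edge features
    node_out, edge_out = mha(node, edge)
    return node_out.shape, edge_out.shape  # (2, 5, 32) and (2, 5, 5, 32)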

class Encoder_Block(nn.Module):
    """
    Transformer encoder block that integrates node and edge features.

    Consists of:
    - A multi-head attention layer with edge modulation.
    - Two MLP modules (one for node features, one for edge features), each followed by a residual connection and layer normalization.

    Attributes:
        ln1, ln3, ln4, ln5, ln6 (nn.LayerNorm): Layer normalization modules.
        attn (MHA): Multi-head attention module.
        mlp, mlp2 (MLP): MLP modules for further transformation of node and edge features.
    """
    def __init__(self, dim, heads, act, mlp_ratio=4, drop_rate=0.):
        """
        Initializes the encoder block.

        Args:
            dim (int): Dimensionality of the input features.
            heads (int): Number of attention heads.
            act (callable): Activation function (not explicitly used in this block, but provided for potential extensions).
            mlp_ratio (int, optional): Ratio to determine the hidden layer size in the MLPs. Defaults to 4.
            drop_rate (float, optional): Dropout rate applied in the MLPs. Defaults to 0.
        """
        super().__init__()

        self.ln1 = nn.LayerNorm(dim)
        self.attn = MHA(dim, heads, drop_rate)
        self.ln3 = nn.LayerNorm(dim)
        self.ln4 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, dim * mlp_ratio, dim, dropout=drop_rate)
        self.mlp2 = MLP(dim, dim * mlp_ratio, dim, dropout=drop_rate)
        self.ln5 = nn.LayerNorm(dim)
        self.ln6 = nn.LayerNorm(dim)

    def forward(self, x, y):
        """
        Forward pass of the encoder block.

        Args:
            x (torch.Tensor): Node feature tensor.
            y (torch.Tensor): Edge feature tensor.

        Returns:
            tuple: (updated node features, updated edge features)
        """
        # Attention over layer-normalized node features, with residual connections
        x1 = self.ln1(x)
        x2, y1 = self.attn(x1, y)
        x2 = x1 + x2
        y2 = y + y1
        x2 = self.ln3(x2)
        y2 = self.ln4(y2)
        # MLPs with residual connections, followed by layer normalization
        x = self.ln5(x2 + self.mlp(x2))
        y = self.ln6(y2 + self.mlp2(y2))
        return x, y
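
# Hedged illustration (not part of the original module): a single Encoder_Block
# preserves the node and edge tensor shapes while mixing the two feature streams.
# The `act` argument is unused by the block, so passing nn.ReLU() here is a
# placeholder assumption, as are the other hyperparameters and sizes.
def _encoder_block_sketch():
    block = Encoder_Block(dim=32, heads=4, act=nn.ReLU(), mlp_ratio=4, drop_rate=0.1)
    x = torch.randn(2, 5, 32)       # node features (batch, num_nodes, dim)
    y = torch.randn(2, 5, 5, 32)    # edge features (batch, num_nodes, num_nodes, dim)
    x_out, y_out = block(x, y)
    return x_out.shape, y_out.shape  # (2, 5, 32) and (2, 5, 5, 32)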

class TransformerEncoder(nn.Module):
    """
    Transformer Encoder composed of a sequence of encoder blocks.

    Attributes:
        Encoder_Blocks (nn.ModuleList): A list of Encoder_Block modules stacked sequentially.
    """
    def __init__(self, dim, depth, heads, act, mlp_ratio=4, drop_rate=0.1):
        """
        Initializes the Transformer Encoder.

        Args:
            dim (int): Dimensionality of the input features.
            depth (int): Number of encoder blocks to stack.
            heads (int): Number of attention heads in each block.
            act (callable): Activation function (passed to the encoder blocks for potential use).
            mlp_ratio (int, optional): Ratio for determining the hidden layer size in MLP modules. Defaults to 4.
            drop_rate (float, optional): Dropout rate for the MLPs within each block. Defaults to 0.1.
        """
        super().__init__()

        self.Encoder_Blocks = nn.ModuleList([
            Encoder_Block(dim, heads, act, mlp_ratio, drop_rate)
            for _ in range(depth)
        ])

    def forward(self, x, y):
        """
        Forward pass of the Transformer Encoder.

        Args:
            x (torch.Tensor): Node feature tensor.
            y (torch.Tensor): Edge feature tensor.

        Returns:
            tuple: (final node features, final edge features) after processing through all encoder blocks.
        """
        for block in self.Encoder_Blocks:
            x, y = block(x, y)
        return x, y
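
if __name__ == "__main__":
    # Minimal smoke test added as a hedged example (not part of the original file).
    # Hyperparameters and tensor sizes are illustrative assumptions only: a small
    # stacked TransformerEncoder should preserve the node and edge shapes.
    encoder = TransformerEncoder(dim=32, depth=2, heads=4, act=nn.ReLU(),
                                 mlp_ratio=4, drop_rate=0.1)
    nodes = torch.randn(2, 5, 32)      # (batch, num_nodes, dim)
    edges = torch.randn(2, 5, 5, 32)   # (batch, num_nodes, num_nodes, dim)
    nodes_out, edges_out = encoder(nodes, edges)
    print(nodes_out.shape, edges_out.shape)  # torch.Size([2, 5, 32]) torch.Size([2, 5, 5, 32])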