# src/model/models.py
import torch
import torch.nn as nn

from src.model.layers import TransformerEncoder


class Generator(nn.Module):
    """
    Generator network that uses a Transformer Encoder to process node and edge features.

    The network first processes input node and edge features with separate linear layers,
    then applies a Transformer Encoder to model interactions, and finally outputs both
    transformed features and readout samples.
    """
    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Generator.

        Args:
            act (str): Type of activation function to use ("relu", "leaky", "sigmoid", or "tanh").
            vertexes (int): Number of vertexes in the graph.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality used for intermediate features.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads in the Transformer.
            mlp_ratio (int): Ratio for determining hidden layer size in MLP modules.
        """
        super(Generator, self).__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Set the activation function based on the provided string
        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()
        else:
            raise ValueError("Unsupported activation function: {}".format(act))

        # Calculate the total number of features and dimensions for the transformer
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
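        # As a concrete check (hypothetical sizes, not taken from this repo):
        # with vertexes=9, edges=5, and nodes=5 the flattened input has
        # 9*9*5 + 9*5 = 450 features, and with dim=64 the transformer-side
        # dimensionality is 9*9*64 + 9*64 = 5760.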

        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )

        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z_e, z_n):
        """
        Forward pass of the Generator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            tuple: A tuple containing:
                - node: Updated node features after the transformer.
                - edge: Updated edge features after the transformer.
                - node_sample: Readout sample from node features.
                - edge_sample: Readout sample from edge features.
        """
        b, n, c = z_n.shape
        # The fourth dimension of edge features (not used further directly)
        _, _, _, d = z_e.shape

        # Process node and edge features through their respective layers
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize the edge features by averaging with their transpose along the vertex dimensions
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        # Pass the features through the Transformer Encoder
        node, edge = self.TransformerEncoder(node, edge)

        # Readout layers to generate final outputs
        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)

        return node, edge, node_sample, edge_sample
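
# A minimal usage sketch (hypothetical hyperparameters, assuming TransformerEncoder
# preserves the node and edge feature shapes as the docstring describes):
#
#     gen = Generator(act="relu", vertexes=9, edges=5, nodes=5,
#                     dropout=0.1, dim=64, depth=2, heads=4, mlp_ratio=4)
#     z_e = torch.randn(8, 9, 9, 5)   # (batch, vertexes, vertexes, edges)
#     z_n = torch.randn(8, 9, 5)      # (batch, vertexes, nodes)
#     node, edge, node_sample, edge_sample = gen(z_e, z_n)
#     # node: (8, 9, 64), edge: (8, 9, 9, 64)
#     # node_sample: (8, 9, 5), edge_sample: (8, 9, 9, 5)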


class Discriminator(nn.Module):
    """
    Discriminator network that evaluates node and edge features.

    It processes features with linear layers, applies a Transformer Encoder to capture dependencies,
    and finally predicts a scalar value using an MLP on aggregated node features.

    This class is used in the DrugGEN model.
    """
    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Discriminator.

        Args:
            act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
            vertexes (int): Number of vertexes.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality for intermediate representations.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads.
            mlp_ratio (int): MLP ratio for hidden layer dimensions.
        """
        super(Discriminator, self).__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Set the activation function
        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()
        else:
            raise ValueError("Unsupported activation function: {}".format(act))

        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # Define layers for processing node and edge features
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        # Transformer Encoder for modeling node and edge interactions
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )
        # Calculate dimensions for node and edge feature aggregation
        self.node_features = vertexes * dim
        self.edge_features = vertexes * vertexes * dim
        # MLP to predict a scalar value from aggregated node features
        self.node_mlp = nn.Sequential(
            nn.Linear(self.node_features, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, z_e, z_n):
        """
        Forward pass of the Discriminator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            torch.Tensor: Prediction scores (typically a scalar per sample).
        """
        b, n, c = z_n.shape
        # Unpack the shape of edge features (not used further directly)
        _, _, _, d = z_e.shape

        # Process node and edge features separately
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize edge features by averaging with their transpose
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        # Process features through the Transformer Encoder
        node, edge = self.TransformerEncoder(node, edge)

        # Flatten node features for the MLP
        node = node.view(b, -1)
        # Predict a scalar score using the node MLP
        prediction = self.node_mlp(node)

        return prediction
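
# A minimal usage sketch (hypothetical sizes, mirroring the Generator example above):
#
#     disc = Discriminator(act="relu", vertexes=9, edges=5, nodes=5,
#                          dropout=0.1, dim=64, depth=2, heads=4, mlp_ratio=4)
#     z_e = torch.randn(8, 9, 9, 5)   # (batch, vertexes, vertexes, edges)
#     z_n = torch.randn(8, 9, 5)      # (batch, vertexes, nodes)
#     score = disc(z_e, z_n)          # (8, 1): one scalar logit per sample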


class simple_disc(nn.Module):
    """
    A simplified discriminator that processes flattened features through an MLP
    to predict a scalar score.

    This class is used in the NoTarget model.
    """
    def __init__(self, act, m_dim, vertexes, b_dim):
        """
        Initializes the simple discriminator.

        Args:
            act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
            m_dim (int): Dimensionality for atom type features.
            vertexes (int): Number of vertexes.
            b_dim (int): Dimensionality for bond type features.
        """
        super().__init__()

        # Set the activation function and check that it is supported
        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()
        else:
            raise ValueError("Unsupported activation function: {}".format(act))

        # Compute the total number of features combining both dimensions
        features = vertexes * m_dim + vertexes * vertexes * b_dim
        self.predictor = nn.Sequential(
            nn.Linear(features, 256), act,
            nn.Linear(256, 128), act,
            nn.Linear(128, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, x):
        """
        Forward pass of the simple discriminator.

        Args:
            x (torch.Tensor): Input features tensor.

        Returns:
            torch.Tensor: Prediction scores.
        """
        prediction = self.predictor(x)
        return prediction
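
# A minimal usage sketch for simple_disc (hypothetical sizes; the input is the
# concatenation of flattened node and edge feature tensors):
#
#     disc = simple_disc(act="relu", m_dim=5, vertexes=9, b_dim=5)
#     x = torch.randn(8, 9 * 5 + 9 * 9 * 5)   # (batch, vertexes*m_dim + vertexes**2*b_dim)
#     score = disc(x)                          # (8, 1)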