Diff of /src/model/models.py [000000] .. [7d53f6]

Switch to unified view

a b/src/model/models.py
1
import torch
2
import torch.nn as nn
3
from src.model.layers import TransformerEncoder
4
5
class Generator(nn.Module):
    """
    Generator network that uses a Transformer Encoder to process node and edge features.

    Node and edge inputs are first embedded by two separate two-layer MLPs, the
    edge embedding is symmetrized over its two vertex axes, and both streams are
    passed through a Transformer Encoder. Linear readout layers then project the
    transformed features back to ``nodes`` and ``edges`` channels.
    """

    # Supported activation names and the module class each one maps to.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "leaky": nn.LeakyReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
    }

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Generator.

        Args:
            act (str | nn.Module): Activation to use. Either one of the names
                "relu", "leaky", "sigmoid" or "tanh", or an already constructed
                activation module.
            vertexes (int): Number of vertexes in the graph.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality used for intermediate features.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads in the Transformer.
            mlp_ratio (int): Ratio for determining hidden layer size in MLP modules.

        Raises:
            ValueError: If ``act`` is neither a supported name nor an ``nn.Module``.
        """
        super(Generator, self).__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Fail early on an unknown activation instead of letting nn.Sequential
        # reject a plain string later with an unrelated-looking error.
        act = self._resolve_activation(act)

        # Flattened sizes of the raw inputs and of the embedded representation.
        # Not used inside this class; kept as public attributes.
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # The same activation instance is reused at every position below.
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )

        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        # Defined for callers; forward() itself does not apply it.
        self.softmax = nn.Softmax(dim=-1)

    @classmethod
    def _resolve_activation(cls, act):
        """Return an activation module for ``act`` (a supported name or a module)."""
        if isinstance(act, nn.Module):
            return act
        factory = cls._ACTIVATIONS.get(act) if isinstance(act, str) else None
        if factory is None:
            raise ValueError("Unsupported activation function: {}".format(act))
        return factory()

    def forward(self, z_e, z_n):
        """
        Forward pass of the Generator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            tuple: A tuple containing:
                - node: Node features returned by the Transformer Encoder.
                - edge: Edge features returned by the Transformer Encoder.
                - node_sample: Output of the node readout layer (no softmax applied).
                - edge_sample: Output of the edge readout layer (no softmax applied).

        Raises:
            ValueError: If ``z_n`` is not 3-dimensional or ``z_e`` is not 4-dimensional.
        """
        # Validate ranks up front so a wrongly shaped input is reported clearly.
        if len(z_n.shape) != 3:
            raise ValueError(
                "z_n must have 3 dimensions (batch, vertexes, nodes), got shape {}".format(tuple(z_n.shape))
            )
        if len(z_e.shape) != 4:
            raise ValueError(
                "z_e must have 4 dimensions (batch, vertexes, vertexes, edges), got shape {}".format(tuple(z_e.shape))
            )

        # Embed node and edge features through their respective layers
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize the edge features by averaging with the transpose over the two vertex axes
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        # Pass the features through the Transformer Encoder
        node, edge = self.TransformerEncoder(node, edge)

        # Readout layers project the last (dim-sized) axis back to nodes / edges channels
        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)

        return node, edge, node_sample, edge_sample
104
105
106
class Discriminator(nn.Module):
    """
    Discriminator network that evaluates node and edge features.

    Node and edge inputs are embedded by separate two-layer MLPs, the edge
    embedding is symmetrized, and both streams are passed through a Transformer
    Encoder. Only the resulting node features are then flattened and fed to an
    MLP that predicts one scalar per sample; the edge output of the encoder is
    not used after that point.

    This class is used in DrugGEN model.
    """

    # Supported activation names and the module class each one maps to.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "leaky": nn.LeakyReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
    }

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Discriminator.

        Args:
            act (str | nn.Module): Activation to use. Either one of the names
                "relu", "leaky", "sigmoid" or "tanh", or an already constructed
                activation module.
            vertexes (int): Number of vertexes.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality for intermediate representations.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads.
            mlp_ratio (int): MLP ratio for hidden layer dimensions.

        Raises:
            ValueError: If ``act`` is neither a supported name nor an ``nn.Module``.
        """
        super(Discriminator, self).__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Fail early on an unknown activation instead of letting nn.Sequential
        # reject a plain string later with an unrelated-looking error.
        act = self._resolve_activation(act)

        # Flattened sizes of the raw inputs and of the embedded representation.
        # Not used inside this class; kept as public attributes.
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # Define layers for processing node and edge features.
        # The same activation instance is reused at every position below.
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        # Transformer Encoder for modeling node and edge interactions
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )
        # Flattened sizes of the per-sample node and edge representations.
        # Only node_features is consumed (by node_mlp); edge_features is kept
        # as an attribute but no layer in this class uses it.
        self.node_features = vertexes * dim
        self.edge_features = vertexes * vertexes * dim
        # MLP to predict a scalar value from the flattened node features
        self.node_mlp = nn.Sequential(
            nn.Linear(self.node_features, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    @classmethod
    def _resolve_activation(cls, act):
        """Return an activation module for ``act`` (a supported name or a module)."""
        if isinstance(act, nn.Module):
            return act
        factory = cls._ACTIVATIONS.get(act) if isinstance(act, str) else None
        if factory is None:
            raise ValueError("Unsupported activation function: {}".format(act))
        return factory()

    def forward(self, z_e, z_n):
        """
        Forward pass of the Discriminator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).
                The flattened encoder output must have ``vertexes * dim``
                entries per sample to match the first layer of ``node_mlp``.

        Returns:
            torch.Tensor: Prediction scores, one value per sample in the batch.

        Raises:
            ValueError: If ``z_n`` is not 3-dimensional or ``z_e`` is not 4-dimensional.
        """
        # Validate ranks up front so a wrongly shaped input is reported clearly.
        if len(z_n.shape) != 3:
            raise ValueError(
                "z_n must have 3 dimensions (batch, vertexes, nodes), got shape {}".format(tuple(z_n.shape))
            )
        if len(z_e.shape) != 4:
            raise ValueError(
                "z_e must have 4 dimensions (batch, vertexes, vertexes, edges), got shape {}".format(tuple(z_e.shape))
            )
        batch_size = z_n.shape[0]

        # Embed node and edge features separately
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize edge features by averaging with the transpose over the two vertex axes
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        # Process features through the Transformer Encoder; the edge output is not used further
        node, edge = self.TransformerEncoder(node, edge)

        # Flatten node features for the MLP. reshape (rather than view) also
        # handles an encoder output that is not contiguous in memory.
        node = node.reshape(batch_size, -1)
        # Predict a scalar score using the node MLP
        prediction = self.node_mlp(node)

        return prediction
210
211
212
class simple_disc(nn.Module):
    """
    A simplified discriminator that processes flattened features through an MLP
    to predict a scalar score.

    This class is used in NoTarget model.
    """

    # Supported activation names and the module class each one maps to.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "leaky": nn.LeakyReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
    }

    def __init__(self, act, m_dim, vertexes, b_dim):
        """
        Initializes the simple discriminator.

        Args:
            act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
            m_dim (int): Dimensionality for atom type features.
            vertexes (int): Number of vertexes.
            b_dim (int): Dimensionality for bond type features.

        Raises:
            ValueError: If ``act`` is not one of the supported names.
        """
        super().__init__()

        # Look up the activation and reject anything that is not supported
        factory = self._ACTIVATIONS.get(act) if isinstance(act, str) else None
        if factory is None:
            raise ValueError("Unsupported activation function: {}".format(act))
        act = factory()

        # Width of the flattened input: per-vertex atom features plus
        # per-vertex-pair bond features.
        features = vertexes * m_dim + vertexes * vertexes * b_dim

        # The same activation instance is reused between all linear layers.
        self.predictor = nn.Sequential(
            nn.Linear(features, 256), act,
            nn.Linear(256, 128), act,
            nn.Linear(128, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, x):
        """
        Forward pass of the simple discriminator.

        Args:
            x (torch.Tensor): Input features tensor whose last dimension equals
                ``vertexes * m_dim + vertexes * vertexes * b_dim``.

        Returns:
            torch.Tensor: Prediction scores with a last dimension of size 1.
        """
        return self.predictor(x)