--- /dev/null
+++ b/src/model/models.py
@@ -0,0 +1,269 @@
+import torch
+import torch.nn as nn
+
+from src.model.layers import TransformerEncoder
+
+
+class Generator(nn.Module):
+    """
+    Generator network that uses a Transformer Encoder to process node and edge features.
+
+    The network first processes input node and edge features with separate linear layers,
+    then applies a Transformer Encoder to model interactions, and finally outputs both the
+    transformed features and readout samples.
+    """
+    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
+        """
+        Initializes the Generator.
+
+        Args:
+            act (str): Type of activation function to use ("relu", "leaky", "sigmoid", or "tanh").
+            vertexes (int): Number of vertexes in the graph.
+            edges (int): Number of edge features.
+            nodes (int): Number of node features.
+            dropout (float): Dropout rate.
+            dim (int): Dimensionality used for intermediate features.
+            depth (int): Number of Transformer encoder blocks.
+            heads (int): Number of attention heads in the Transformer.
+            mlp_ratio (int): Ratio for determining hidden layer size in MLP modules.
+        """
+        super(Generator, self).__init__()
+        self.vertexes = vertexes
+        self.edges = edges
+        self.nodes = nodes
+        self.depth = depth
+        self.dim = dim
+        self.heads = heads
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+
+        # Set the activation function based on the provided string
+        if act == "relu":
+            act = nn.ReLU()
+        elif act == "leaky":
+            act = nn.LeakyReLU()
+        elif act == "sigmoid":
+            act = nn.Sigmoid()
+        elif act == "tanh":
+            act = nn.Tanh()
+        else:
+            raise ValueError("Unsupported activation function: {}".format(act))
+
+        # Total flattened feature size and the corresponding transformer dimension
+        self.features = vertexes * vertexes * edges + vertexes * nodes
+        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
+
+        self.node_layers = nn.Sequential(
+            nn.Linear(nodes, 64), act,
+            nn.Linear(64, dim), act,
+            nn.Dropout(self.dropout)
+        )
+        self.edge_layers = nn.Sequential(
+            nn.Linear(edges, 64), act,
+            nn.Linear(64, dim), act,
+            nn.Dropout(self.dropout)
+        )
+        self.TransformerEncoder = TransformerEncoder(
+            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
+            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
+        )
+
+        self.readout_e = nn.Linear(self.dim, edges)
+        self.readout_n = nn.Linear(self.dim, nodes)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, z_e, z_n):
+        """
+        Forward pass of the Generator.
+
+        Args:
+            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
+            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).
+
+        Returns:
+            tuple: A tuple containing:
+                - node: Updated node features after the transformer.
+                - edge: Updated edge features after the transformer.
+                - node_sample: Readout sample from node features.
+                - edge_sample: Readout sample from edge features.
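+
+        Example (illustrative sketch only; the hyperparameters are arbitrary and
+        the output shapes assume TransformerEncoder preserves its input shapes)::
+
+            g = Generator("relu", vertexes=9, edges=5, nodes=8,
+                          dropout=0.1, dim=64, depth=2, heads=4, mlp_ratio=4)
+            z_e = torch.randn(2, 9, 9, 5)  # (batch, vertexes, vertexes, edges)
+            z_n = torch.randn(2, 9, 8)     # (batch, vertexes, nodes)
+            node, edge, node_sample, edge_sample = g(z_e, z_n)
+            # node_sample: (2, 9, 8), edge_sample: (2, 9, 9, 5)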
+ """ + b, n, c = z_n.shape + # The fourth dimension of edge features + _, _, _, d = z_e.shape + + # Process node and edge features through their respective layers + node = self.node_layers(z_n) + edge = self.edge_layers(z_e) + # Symmetrize the edge features by averaging with its transpose along vertex dimensions + edge = (edge + edge.permute(0, 2, 1, 3)) / 2 + + # Pass the features through the Transformer Encoder + node, edge = self.TransformerEncoder(node, edge) + + # Readout layers to generate final outputs + node_sample = self.readout_n(node) + edge_sample = self.readout_e(edge) + + return node, edge, node_sample, edge_sample + + +class Discriminator(nn.Module): + """ + Discriminator network that evaluates node and edge features. + + It processes features with linear layers, applies a Transformer Encoder to capture dependencies, + and finally predicts a scalar value using an MLP on aggregated node features. + + This class is used in DrugGEN model. + """ + def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio): + """ + Initializes the Discriminator. + + Args: + act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh"). + vertexes (int): Number of vertexes. + edges (int): Number of edge features. + nodes (int): Number of node features. + dropout (float): Dropout rate. + dim (int): Dimensionality for intermediate representations. + depth (int): Number of Transformer encoder blocks. + heads (int): Number of attention heads. + mlp_ratio (int): MLP ratio for hidden layer dimensions. + """ + super(Discriminator, self).__init__() + self.vertexes = vertexes + self.edges = edges + self.nodes = nodes + self.depth = depth + self.dim = dim + self.heads = heads + self.mlp_ratio = mlp_ratio + self.dropout = dropout + + # Set the activation function + if act == "relu": + act = nn.ReLU() + elif act == "leaky": + act = nn.LeakyReLU() + elif act == "sigmoid": + act = nn.Sigmoid() + elif act == "tanh": + act = nn.Tanh() + + self.features = vertexes * vertexes * edges + vertexes * nodes + self.transformer_dim = vertexes * vertexes * dim + vertexes * dim + + # Define layers for processing node and edge features + self.node_layers = nn.Sequential( + nn.Linear(nodes, 64), act, + nn.Linear(64, dim), act, + nn.Dropout(self.dropout) + ) + self.edge_layers = nn.Sequential( + nn.Linear(edges, 64), act, + nn.Linear(64, dim), act, + nn.Dropout(self.dropout) + ) + # Transformer Encoder for modeling node and edge interactions + self.TransformerEncoder = TransformerEncoder( + dim=self.dim, depth=self.depth, heads=self.heads, act=act, + mlp_ratio=self.mlp_ratio, drop_rate=self.dropout + ) + # Calculate dimensions for node features aggregation + self.node_features = vertexes * dim + self.edge_features = vertexes * vertexes * dim + # MLP to predict a scalar value from aggregated node features + self.node_mlp = nn.Sequential( + nn.Linear(self.node_features, 64), act, + nn.Linear(64, 32), act, + nn.Linear(32, 16), act, + nn.Linear(16, 1) + ) + + def forward(self, z_e, z_n): + """ + Forward pass of the Discriminator. + + Args: + z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges). + z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes). + + Returns: + torch.Tensor: Prediction scores (typically a scalar per sample). 
+ """ + b, n, c = z_n.shape + # Unpack the shape of edge features (not used further directly) + _, _, _, d = z_e.shape + + # Process node and edge features separately + node = self.node_layers(z_n) + edge = self.edge_layers(z_e) + # Symmetrize edge features by averaging with its transpose + edge = (edge + edge.permute(0, 2, 1, 3)) / 2 + + # Process features through the Transformer Encoder + node, edge = self.TransformerEncoder(node, edge) + + # Flatten node features for MLP + node = node.view(b, -1) + # Predict a scalar score using the node MLP + prediction = self.node_mlp(node) + + return prediction + + +class simple_disc(nn.Module): + """ + A simplified discriminator that processes flattened features through an MLP + to predict a scalar score. + + This class is used in NoTarget model. + """ + def __init__(self, act, m_dim, vertexes, b_dim): + """ + Initializes the simple discriminator. + + Args: + act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh"). + m_dim (int): Dimensionality for atom type features. + vertexes (int): Number of vertexes. + b_dim (int): Dimensionality for bond type features. + """ + super().__init__() + + # Set the activation function and check if it's supported + if act == "relu": + act = nn.ReLU() + elif act == "leaky": + act = nn.LeakyReLU() + elif act == "sigmoid": + act = nn.Sigmoid() + elif act == "tanh": + act = nn.Tanh() + else: + raise ValueError("Unsupported activation function: {}".format(act)) + + # Compute total number of features combining both dimensions + features = vertexes * m_dim + vertexes * vertexes * b_dim + print(vertexes) + print(m_dim) + print(b_dim) + print(features) + self.predictor = nn.Sequential( + nn.Linear(features, 256), act, + nn.Linear(256, 128), act, + nn.Linear(128, 64), act, + nn.Linear(64, 32), act, + nn.Linear(32, 16), act, + nn.Linear(16, 1) + ) + + def forward(self, x): + """ + Forward pass of the simple discriminator. + + Args: + x (torch.Tensor): Input features tensor. + + Returns: + torch.Tensor: Prediction scores. + """ + prediction = self.predictor(x) + return prediction \ No newline at end of file