--- /dev/null
+++ b/src/model/loss.py
@@ -0,0 +1,85 @@
+import torch
+
+
+def gradient_penalty(discriminator, real_node, real_edge, fake_node, fake_edge, batch_size, device):
+    """
+    Calculate the gradient penalty term for WGAN-GP.
+
+    Args:
+        discriminator: The discriminator (critic) model
+        real_node: Real node features
+        real_edge: Real edge features
+        fake_node: Generated node features
+        fake_edge: Generated edge features
+        batch_size: Batch size
+        device: Device to compute on
+
+    Returns:
+        Scalar gradient penalty term
+    """
+    # Sample random interpolation factors, broadcastable over the
+    # node (B, N, C) and edge (B, N, N, C) feature shapes
+    eps_edge = torch.rand(batch_size, 1, 1, 1, device=device)
+    eps_node = torch.rand(batch_size, 1, 1, device=device)
+
+    # Interpolate between real and generated samples
+    int_node = (eps_node * real_node + (1 - eps_node) * fake_node).requires_grad_(True)
+    int_edge = (eps_edge * real_edge + (1 - eps_edge) * fake_edge).requires_grad_(True)
+
+    logits_interpolated = discriminator(int_edge, int_node)
+
+    # Differentiate the critic output with respect to both inputs
+    weight = torch.ones_like(logits_interpolated)
+    gradients = torch.autograd.grad(
+        outputs=logits_interpolated,
+        inputs=[int_node, int_edge],
+        grad_outputs=weight,
+        create_graph=True,
+        retain_graph=True,
+        only_inputs=True,
+    )
+
+    # Concatenate per-sample gradients from both inputs
+    gradients_node = gradients[0].view(batch_size, -1)
+    gradients_edge = gradients[1].view(batch_size, -1)
+    gradients = torch.cat([gradients_node, gradients_edge], dim=1)
+
+    # Penalize deviations of the per-sample gradient norm from 1
+    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+
+    return penalty
+
+
+def discriminator_loss(generator, discriminator, drug_adj, drug_annot, mol_adj, mol_annot, batch_size, device, lambda_gp):
+    """Compute the WGAN-GP critic loss on real drugs and generated molecules."""
+    # Critic score on real molecules; the WGAN critic maximizes
+    # D(real) - D(fake), i.e. minimizes -D(real) + D(fake)
+    logits_real_disc = discriminator(drug_adj, drug_annot)
+    prediction_real = -torch.mean(logits_real_disc)
+
+    # Critic score on generated molecules; detach so no gradients
+    # flow back into the generator during the critic update
+    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)
+    logits_fake_disc = discriminator(edge_sample.detach(), node_sample.detach())
+    prediction_fake = torch.mean(logits_fake_disc)
+
+    # Gradient penalty on interpolations between real and fake samples
+    gp = gradient_penalty(discriminator, drug_annot, drug_adj, node_sample.detach(), edge_sample.detach(), batch_size, device)
+
+    # Total critic loss
+    d_loss = prediction_fake + prediction_real + lambda_gp * gp
+
+    return node, edge, d_loss
+
+
+def generator_loss(generator, discriminator, mol_adj, mol_annot, batch_size):
+    """Compute the WGAN generator loss on generated molecules."""
+    # Generate fake molecules
+    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)
+
+    # The generator maximizes the critic score of its samples,
+    # i.e. minimizes -D(fake)
+    logits_fake_disc = discriminator(edge_sample, node_sample)
+    g_loss = -torch.mean(logits_fake_disc)
+
+    return g_loss, node, edge, node_sample, edge_sample
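
A minimal sketch of how these losses slot into an alternating WGAN-GP training step. The `loader`, `d_optimizer`, `g_optimizer`, and the `generator`/`discriminator` instances are hypothetical names not defined in this patch; `lambda_gp = 10` and `n_critic = 5` are the values commonly used with WGAN-GP, not values this patch prescribes.

```python
# Hypothetical training step; the loader, optimizers, and model
# instances below are assumptions, not part of this patch.
import torch

from src.model.loss import discriminator_loss, generator_loss

lambda_gp = 10.0  # gradient penalty weight commonly used with WGAN-GP
n_critic = 5      # critic updates per generator update
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for step, (drug_adj, drug_annot, mol_adj, mol_annot) in enumerate(loader):
    batch_size = drug_adj.size(0)

    # Critic update: fake samples are detached inside discriminator_loss,
    # so only the critic's parameters receive gradients here
    d_optimizer.zero_grad()
    _, _, d_loss = discriminator_loss(
        generator, discriminator,
        drug_adj, drug_annot, mol_adj, mol_annot,
        batch_size, device, lambda_gp,
    )
    d_loss.backward()
    d_optimizer.step()

    # Generator update once every n_critic critic updates
    if step % n_critic == 0:
        g_optimizer.zero_grad()
        g_loss, *_ = generator_loss(
            generator, discriminator, mol_adj, mol_annot, batch_size
        )
        g_loss.backward()
        g_optimizer.step()
```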