b/src/model/loss.py
import torch


def gradient_penalty(discriminator, real_node, real_edge, fake_node, fake_edge, batch_size, device):
    """
    Calculate the gradient penalty for WGAN-GP.

    Args:
        discriminator: The discriminator (critic) model, called as discriminator(edge, node)
        real_node: Real node features
        real_edge: Real edge features
        fake_node: Generated node features
        fake_edge: Generated edge features
        batch_size: Batch size
        device: Device to compute on

    Returns:
        Gradient penalty term (scalar tensor)
    """
    # Sample per-example interpolation factors, shaped to broadcast over the
    # node and edge feature dimensions
    eps_edge = torch.rand(batch_size, 1, 1, 1, device=device)
    eps_node = torch.rand(batch_size, 1, 1, device=device)

    # Interpolate between real and generated samples; the penalty is evaluated
    # at these points, so they must require gradients
    int_node = (eps_node * real_node + (1 - eps_node) * fake_node).requires_grad_(True)
    int_edge = (eps_edge * real_edge + (1 - eps_edge) * fake_edge).requires_grad_(True)

    logits_interpolated = discriminator(int_edge, int_node)

    # Differentiate the critic output with respect to both node and edge inputs
    weight = torch.ones_like(logits_interpolated)
    gradients = torch.autograd.grad(
        outputs=logits_interpolated,
        inputs=[int_node, int_edge],
        grad_outputs=weight,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    # Flatten and concatenate the gradients from both inputs so the norm is
    # taken over the joint (node, edge) gradient
    gradients_node = gradients[0].view(batch_size, -1)
    gradients_edge = gradients[1].view(batch_size, -1)
    gradients = torch.cat([gradients_node, gradients_edge], dim=1)

    # Penalize deviations of the joint gradient norm from 1
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return gradient_penalty
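
# The regularizer above is the WGAN-GP penalty of Gulrajani et al. (2017):
# lambda_gp * E[(||grad_x D(x_hat)||_2 - 1)^2], evaluated at random
# interpolates x_hat between real and generated samples. A minimal call-site
# sketch; the (batch, nodes, features) node shape and
# (batch, nodes, nodes, bond_types) edge shape are assumptions chosen to match
# the eps broadcasting above, not something this module fixes:
#
#     gp = gradient_penalty(disc, real_node, real_edge,
#                           fake_node.detach(), fake_edge.detach(),
#                           real_node.size(0), real_node.device)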


def discriminator_loss(generator, discriminator, drug_adj, drug_annot, mol_adj, mol_annot, batch_size, device, lambda_gp):
    # Critic score for real drugs; the Wasserstein loss maximizes E[D(real)],
    # so the mean score enters the loss with a negative sign
    logits_real_disc = discriminator(drug_adj, drug_annot)
    prediction_real = -torch.mean(logits_real_disc)

    # Critic score for generated molecules; detach the samples so no gradients
    # flow back into the generator during the critic update
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)
    logits_fake_disc = discriminator(edge_sample.detach(), node_sample.detach())
    prediction_fake = torch.mean(logits_fake_disc)

    # Gradient penalty on interpolates between real and generated samples
    gp = gradient_penalty(discriminator, drug_annot, drug_adj, node_sample.detach(), edge_sample.detach(), batch_size, device)

    # Total critic loss: E[D(fake)] - E[D(real)] + lambda_gp * GP
    d_loss = prediction_fake + prediction_real + lambda_gp * gp

    return node, edge, d_loss
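
# A sketch of one critic update using the function above (`opt_d`, `gen`, and
# `disc` are hypothetical names, not defined in this module; lambda_gp=10.0
# follows the default used in the WGAN-GP paper):
#
#     opt_d.zero_grad()
#     _, _, d_loss = discriminator_loss(gen, disc, drug_adj, drug_annot,
#                                       mol_adj, mol_annot, batch_size,
#                                       device, lambda_gp=10.0)
#     d_loss.backward()
#     opt_d.step()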


def generator_loss(generator, discriminator, mol_adj, mol_annot, batch_size):
    # Generate fake molecules
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)

    # Score the generated molecules; the generator is trained to maximize the
    # critic's output, i.e. to minimize -E[D(fake)]
    logits_fake_disc = discriminator(edge_sample, node_sample)
    g_loss = -torch.mean(logits_fake_disc)

    return g_loss, node, edge, node_sample, edge_sample
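

if __name__ == "__main__":
    # Smoke test for gradient_penalty using a hypothetical toy critic; the
    # real generator and discriminator live elsewhere in this repo, and the
    # sizes below (4 nodes, 5 atom types, 3 bond types) are illustrative only.
    import torch.nn as nn

    class ToyCritic(nn.Module):
        def __init__(self, n_nodes, node_feats, edge_feats):
            super().__init__()
            in_dim = n_nodes * n_nodes * edge_feats + n_nodes * node_feats
            self.fc = nn.Linear(in_dim, 1)

        def forward(self, edge, node):
            # Flatten the edge and node tensors and score them jointly
            batch = node.size(0)
            flat = torch.cat([edge.reshape(batch, -1), node.reshape(batch, -1)], dim=1)
            return self.fc(flat)

    batch, n_nodes, node_feats, edge_feats = 8, 4, 5, 3
    critic = ToyCritic(n_nodes, node_feats, edge_feats)
    real_node = torch.rand(batch, n_nodes, node_feats)
    fake_node = torch.rand(batch, n_nodes, node_feats)
    real_edge = torch.rand(batch, n_nodes, n_nodes, edge_feats)
    fake_edge = torch.rand(batch, n_nodes, n_nodes, edge_feats)

    gp = gradient_penalty(critic, real_node, real_edge, fake_node, fake_edge,
                          batch, torch.device("cpu"))
    print(f"gradient penalty: {gp.item():.4f}")  # a finite, non-negative scalar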