Diff of /src/model/loss.py [000000] .. [7d53f6]

Switch to unified view

a b/src/model/loss.py
1
import torch
2
3
4
def gradient_penalty(discriminator, real_node, real_edge, fake_node, fake_edge, batch_size, device):
5
    """
6
    Calculate gradient penalty for WGAN-GP.
7
    
8
    Args:
9
        discriminator: The discriminator model
10
        real_node: Real node features
11
        real_edge: Real edge features
12
        fake_node: Generated node features
13
        fake_edge: Generated edge features
14
        batch_size: Batch size
15
        device: Device to compute on
16
        
17
    Returns:
18
        Gradient penalty term
19
    """
20
    # Generate random interpolation factors
21
    eps_edge = torch.rand(batch_size, 1, 1, 1, device=device)
22
    eps_node = torch.rand(batch_size, 1, 1, device=device)
23
    
24
    # Create interpolated samples
25
    int_node = (eps_node * real_node + (1 - eps_node) * fake_node).requires_grad_(True)
26
    int_edge = (eps_edge * real_edge + (1 - eps_edge) * fake_edge).requires_grad_(True)
27
28
    logits_interpolated = discriminator(int_edge, int_node)
29
30
    # Calculate gradients for both node and edge inputs
31
    weight = torch.ones(logits_interpolated.size(), requires_grad=False).to(device)
32
    gradients = torch.autograd.grad(
33
        outputs=logits_interpolated,
34
        inputs=[int_node, int_edge],
35
        grad_outputs=weight,
36
        create_graph=True,
37
        retain_graph=True,
38
        only_inputs=True
39
    )
40
41
    # Combine gradients from both inputs
42
    gradients_node = gradients[0].view(batch_size, -1)
43
    gradients_edge = gradients[1].view(batch_size, -1)
44
    gradients = torch.cat([gradients_node, gradients_edge], dim=1)
45
46
    # Calculate gradient penalty
47
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
48
    
49
    return gradient_penalty
50
51
52
def discriminator_loss(generator, discriminator, drug_adj, drug_annot, mol_adj, mol_annot, batch_size, device, lambda_gp):
    """
    Compute the WGAN-GP discriminator (critic) loss.

    Scores real drug graphs and generated molecule graphs with the critic,
    then adds the gradient penalty weighted by ``lambda_gp``.

    Returns:
        (node, edge, d_loss): the generator's raw node/edge outputs plus the
        scalar critic loss.
    """
    # Critic score on real drug graphs; negated because the critic is trained
    # to assign high scores to real samples (mean reduction for stability).
    real_logits = discriminator(drug_adj, drug_annot)
    loss_real = -torch.mean(real_logits)

    # Sample fake molecule graphs from the generator.
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)

    # Detach so no gradients flow back into the generator during this update.
    fake_edge = edge_sample.detach()
    fake_node = node_sample.detach()
    fake_logits = discriminator(fake_edge, fake_node)
    loss_fake = torch.mean(fake_logits)

    # Gradient penalty on interpolations between real and generated graphs.
    gp = gradient_penalty(discriminator, drug_annot, drug_adj, fake_node, fake_edge, batch_size, device)

    d_loss = loss_fake + loss_real + lambda_gp * gp
    return node, edge, d_loss
73
74
75
def generator_loss(generator, discriminator, mol_adj, mol_annot, batch_size):
    """
    Compute the WGAN generator loss: the negated mean critic score on
    generated samples. ``batch_size`` is accepted for interface symmetry
    with ``discriminator_loss`` but is not used here.

    Returns:
        (g_loss, node, edge, node_sample, edge_sample)
    """
    # Sample fake molecules and score them with the critic. No detach here:
    # gradients must flow from the critic back into the generator.
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)
    fake_logits = discriminator(edge_sample, node_sample)

    # The generator maximizes the critic score, i.e. minimizes its negation.
    g_loss = -torch.mean(fake_logits)

    return g_loss, node, edge, node_sample, edge_sample