Switch to unified view

a b/equivariant_diffusion/dynamics.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
from equivariant_diffusion.egnn_new import EGNN, GNN
5
from equivariant_diffusion.en_diffusion import EnVariationalDiffusion
6
remove_mean_batch = EnVariationalDiffusion.remove_mean_batch
7
import numpy as np
8
9
10
class EGNNDynamics(nn.Module):
11
    def __init__(self, atom_nf, residue_nf,
12
                 n_dims, joint_nf=16, hidden_nf=64, device='cpu',
13
                 act_fn=torch.nn.SiLU(), n_layers=4, attention=False,
14
                 condition_time=True, tanh=False, mode='egnn_dynamics',
15
                 norm_constant=0, inv_sublayers=2, sin_embedding=False,
16
                 normalization_factor=100, aggregation_method='sum',
17
                 update_pocket_coords=True, edge_cutoff_ligand=None,
18
                 edge_cutoff_pocket=None, edge_cutoff_interaction=None,
19
                 reflection_equivariant=True, edge_embedding_dim=None):
20
        super().__init__()
21
        self.mode = mode
22
        self.edge_cutoff_l = edge_cutoff_ligand
23
        self.edge_cutoff_p = edge_cutoff_pocket
24
        self.edge_cutoff_i = edge_cutoff_interaction
25
        self.edge_nf = edge_embedding_dim
26
27
        self.atom_encoder = nn.Sequential(
28
            nn.Linear(atom_nf, 2 * atom_nf),
29
            act_fn,
30
            nn.Linear(2 * atom_nf, joint_nf)
31
        )
32
33
        self.atom_decoder = nn.Sequential(
34
            nn.Linear(joint_nf, 2 * atom_nf),
35
            act_fn,
36
            nn.Linear(2 * atom_nf, atom_nf)
37
        )
38
39
        self.residue_encoder = nn.Sequential(
40
            nn.Linear(residue_nf, 2 * residue_nf),
41
            act_fn,
42
            nn.Linear(2 * residue_nf, joint_nf)
43
        )
44
45
        self.residue_decoder = nn.Sequential(
46
            nn.Linear(joint_nf, 2 * residue_nf),
47
            act_fn,
48
            nn.Linear(2 * residue_nf, residue_nf)
49
        )
50
51
        self.edge_embedding = nn.Embedding(3, self.edge_nf) \
52
            if self.edge_nf is not None else None
53
        self.edge_nf = 0 if self.edge_nf is None else self.edge_nf
54
55
        if condition_time:
56
            dynamics_node_nf = joint_nf + 1
57
        else:
58
            print('Warning: dynamics model is _not_ conditioned on time.')
59
            dynamics_node_nf = joint_nf
60
61
        if mode == 'egnn_dynamics':
62
            self.egnn = EGNN(
63
                in_node_nf=dynamics_node_nf, in_edge_nf=self.edge_nf,
64
                hidden_nf=hidden_nf, device=device, act_fn=act_fn,
65
                n_layers=n_layers, attention=attention, tanh=tanh,
66
                norm_constant=norm_constant,
67
                inv_sublayers=inv_sublayers, sin_embedding=sin_embedding,
68
                normalization_factor=normalization_factor,
69
                aggregation_method=aggregation_method,
70
                reflection_equiv=reflection_equivariant
71
            )
72
            self.node_nf = dynamics_node_nf
73
            self.update_pocket_coords = update_pocket_coords
74
75
        elif mode == 'gnn_dynamics':
76
            self.gnn = GNN(
77
                in_node_nf=dynamics_node_nf + n_dims, in_edge_nf=self.edge_nf,
78
                hidden_nf=hidden_nf, out_node_nf=n_dims + dynamics_node_nf,
79
                device=device, act_fn=act_fn, n_layers=n_layers,
80
                attention=attention, normalization_factor=normalization_factor,
81
                aggregation_method=aggregation_method)
82
83
        self.device = device
84
        self.n_dims = n_dims
85
        self.condition_time = condition_time
86
87
    def forward(self, xh_atoms, xh_residues, t, mask_atoms, mask_residues):
88
89
        x_atoms = xh_atoms[:, :self.n_dims].clone()
90
        h_atoms = xh_atoms[:, self.n_dims:].clone()
91
92
        x_residues = xh_residues[:, :self.n_dims].clone()
93
        h_residues = xh_residues[:, self.n_dims:].clone()
94
95
        # embed atom features and residue features in a shared space
96
        h_atoms = self.atom_encoder(h_atoms)
97
        h_residues = self.residue_encoder(h_residues)
98
99
        # combine the two node types
100
        x = torch.cat((x_atoms, x_residues), dim=0)
101
        h = torch.cat((h_atoms, h_residues), dim=0)
102
        mask = torch.cat([mask_atoms, mask_residues])
103
104
        if self.condition_time:
105
            if np.prod(t.size()) == 1:
106
                # t is the same for all elements in batch.
107
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
108
            else:
109
                # t is different over the batch dimension.
110
                h_time = t[mask]
111
            h = torch.cat([h, h_time], dim=1)
112
113
        # get edges of a complete graph
114
        edges = self.get_edges(mask_atoms, mask_residues, x_atoms, x_residues)
115
        assert torch.all(mask[edges[0]] == mask[edges[1]])
116
117
        # Get edge types
118
        if self.edge_nf > 0:
119
            # 0: ligand-pocket, 1: ligand-ligand, 2: pocket-pocket
120
            edge_types = torch.zeros(edges.size(1), dtype=int, device=edges.device)
121
            edge_types[(edges[0] < len(mask_atoms)) & (edges[1] < len(mask_atoms))] = 1
122
            edge_types[(edges[0] >= len(mask_atoms)) & (edges[1] >= len(mask_atoms))] = 2
123
124
            # Learnable embedding
125
            edge_types = self.edge_embedding(edge_types)
126
        else:
127
            edge_types = None
128
129
        if self.mode == 'egnn_dynamics':
130
            update_coords_mask = None if self.update_pocket_coords \
131
                else torch.cat((torch.ones_like(mask_atoms),
132
                                torch.zeros_like(mask_residues))).unsqueeze(1)
133
            h_final, x_final = self.egnn(h, x, edges,
134
                                         update_coords_mask=update_coords_mask,
135
                                         batch_mask=mask, edge_attr=edge_types)
136
            vel = (x_final - x)
137
138
        elif self.mode == 'gnn_dynamics':
139
            xh = torch.cat([x, h], dim=1)
140
            output = self.gnn(xh, edges, node_mask=None, edge_attr=edge_types)
141
            vel = output[:, :3]
142
            h_final = output[:, 3:]
143
144
        else:
145
            raise Exception("Wrong mode %s" % self.mode)
146
147
        if self.condition_time:
148
            # Slice off last dimension which represented time.
149
            h_final = h_final[:, :-1]
150
151
        # decode atom and residue features
152
        h_final_atoms = self.atom_decoder(h_final[:len(mask_atoms)])
153
        h_final_residues = self.residue_decoder(h_final[len(mask_atoms):])
154
155
        if torch.any(torch.isnan(vel)):
156
            if self.training:
157
                vel[torch.isnan(vel)] = 0.0
158
            else:
159
                raise ValueError("NaN detected in EGNN output")
160
161
        if self.update_pocket_coords:
162
            # in case of unconditional joint distribution, include this as in
163
            # the original code
164
            vel = remove_mean_batch(vel, mask)
165
166
        return torch.cat([vel[:len(mask_atoms)], h_final_atoms], dim=-1), \
167
               torch.cat([vel[len(mask_atoms):], h_final_residues], dim=-1)
168
169
    def get_edges(self, batch_mask_ligand, batch_mask_pocket, x_ligand, x_pocket):
170
        adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
171
        adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
172
        adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]
173
174
        if self.edge_cutoff_l is not None:
175
            adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)
176
177
        if self.edge_cutoff_p is not None:
178
            adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)
179
180
        if self.edge_cutoff_i is not None:
181
            adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)
182
183
        adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1),
184
                         torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0)
185
        edges = torch.stack(torch.where(adj), dim=0)
186
187
        return edges