Diff of /lightning_modules.py [000000] .. [607087]

Switch to unified view

a b/lightning_modules.py
1
import math
2
from argparse import Namespace
3
from typing import Optional
4
from time import time
5
from pathlib import Path
6
7
import numpy as np
8
import torch
9
import torch.nn.functional as F
10
from torch.utils.data import DataLoader
11
import pytorch_lightning as pl
12
import wandb
13
from torch_scatter import scatter_add, scatter_mean
14
from Bio.PDB import PDBParser
15
from Bio.PDB.Polypeptide import three_to_one
16
17
from constants import dataset_params, FLOAT_TYPE, INT_TYPE
18
from equivariant_diffusion.dynamics import EGNNDynamics
19
from equivariant_diffusion.en_diffusion import EnVariationalDiffusion
20
from equivariant_diffusion.conditional_model import ConditionalDDPM, \
21
    SimpleConditionalDDPM
22
from dataset import ProcessedLigandPocketDataset
23
import utils
24
from analysis.visualization import save_xyz_file, visualize, visualize_chain
25
from analysis.metrics import BasicMolecularMetrics, CategoricalDistribution, \
26
    MoleculeProperties
27
from analysis.molecule_builder import build_molecule, process_molecule
28
from analysis.docking import smina_score
29
30
31
class LigandPocketDDPM(pl.LightningModule):
32
    def __init__(
            self,
            outdir,
            dataset,
            datadir,
            batch_size,
            lr,
            egnn_params: Namespace,
            diffusion_params,
            num_workers,
            augment_noise,
            augment_rotation,
            clip_grad,
            eval_epochs,
            eval_params,
            visualize_sample_epoch,
            visualize_chain_epoch,
            auxiliary_loss,
            loss_params,
            mode,
            node_histogram,
            pocket_representation='CA',
            virtual_nodes=False
    ):
        """Store the run configuration and build the diffusion model.

        Args:
            outdir: Root directory for structures written during validation.
            dataset: Key into ``dataset_params`` selecting the dataset info.
            datadir: Directory holding the ``train/val/test.npz`` files.
            batch_size: Batch size of the data loaders; also the fallback
                evaluation batch size.
            lr: Learning rate handed to the optimizer.
            egnn_params: Namespace with the EGNN hyper-parameters. The edge
                cutoffs and ``edge_embedding_dim`` are optional entries.
            diffusion_params: Diffusion settings (number of steps, noise
                schedule and precision, loss type, normalization factors).
            num_workers: Worker count of the data loaders.
            augment_noise: Stored as-is (noise augmentation setting).
            augment_rotation: Stored as-is (rotation augmentation setting).
            clip_grad: If truthy, a gradient-norm queue is created.
            eval_epochs: Period (in epochs) of the sampling evaluation.
            eval_params: Evaluation settings (``smiles_file``, sample counts
                and optionally ``eval_batch_size``).
            visualize_sample_epoch: Period (in epochs) of sample
                visualization.
            visualize_chain_epoch: Period (in epochs) of chain visualization.
            auxiliary_loss: Whether the Lennard-Jones term is added to the
                training loss.
            loss_params: Auxiliary-loss settings; only read when
                ``auxiliary_loss`` is truthy.
            mode: One of 'joint', 'pocket_conditioning' or
                'pocket_conditioning_simple'.
            node_histogram: Size histogram passed to the diffusion model; its
                length minus one is used as the maximum number of nodes.
            pocket_representation: Either 'CA' or 'full-atom'.
            virtual_nodes: Whether ligands are padded with virtual atoms.
        """
        super(LigandPocketDDPM, self).__init__()
        self.save_hyperparameters()

        # Every supported mode maps to the diffusion model implementing it.
        ddpm_models = {
            'joint': EnVariationalDiffusion,
            'pocket_conditioning': ConditionalDDPM,
            'pocket_conditioning_simple': SimpleConditionalDDPM,
        }
        assert mode in ddpm_models
        self.mode = mode
        assert pocket_representation in {'CA', 'full-atom'}
        self.pocket_representation = pocket_representation

        # Plain configuration values
        self.dataset_name = dataset
        self.datadir = datadir
        self.outdir = outdir
        self.batch_size = batch_size
        if 'eval_batch_size' in eval_params:
            self.eval_batch_size = eval_params.eval_batch_size
        else:
            self.eval_batch_size = batch_size
        self.lr = lr
        self.loss_type = diffusion_params.diffusion_loss_type
        self.eval_epochs = eval_epochs
        self.visualize_sample_epoch = visualize_sample_epoch
        self.visualize_chain_epoch = visualize_chain_epoch
        self.eval_params = eval_params
        self.num_workers = num_workers
        self.augment_noise = augment_noise
        self.augment_rotation = augment_rotation
        self.dataset_info = dataset_params[dataset]
        self.T = diffusion_params.diffusion_steps
        self.clip_grad = clip_grad
        if clip_grad:
            self.gradnorm_queue = utils.Queue()
            # Seed the queue with a large value that will be flushed.
            self.gradnorm_queue.add(3000)

        # Type vocabularies: residues for C-alpha pockets, atoms otherwise.
        self.lig_type_encoder = self.dataset_info['atom_encoder']
        self.lig_type_decoder = self.dataset_info['atom_decoder']
        if self.pocket_representation == 'CA':
            self.pocket_type_encoder = self.dataset_info['aa_encoder']
            self.pocket_type_decoder = self.dataset_info['aa_decoder']
        else:
            self.pocket_type_encoder = self.dataset_info['atom_encoder']
            self.pocket_type_decoder = self.dataset_info['atom_decoder']

        # Evaluation helpers
        if eval_params.smiles_file is None:
            smiles_list = None
        else:
            smiles_list = np.load(eval_params.smiles_file)
        self.ligand_metrics = BasicMolecularMetrics(self.dataset_info,
                                                    smiles_list)
        self.molecule_properties = MoleculeProperties()
        self.ligand_type_distribution = CategoricalDistribution(
            self.dataset_info['atom_hist'], self.lig_type_encoder)
        if self.pocket_representation == 'CA':
            self.pocket_type_distribution = CategoricalDistribution(
                self.dataset_info['aa_hist'], self.pocket_type_encoder)
        else:
            self.pocket_type_distribution = None

        # Datasets are created lazily in setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        self.virtual_nodes = virtual_nodes
        self.data_transform = None
        self.max_num_nodes = len(node_histogram) - 1
        if virtual_nodes:
            # symbol = 'virtual'
            symbol = 'Ne'  # visualize as Neon atoms
            # NOTE(review): encoder/decoder are taken from dataset_params
            # without a visible copy, so the updates below appear to modify
            # the shared dataset info in place - confirm this is intended.
            virtual_index = len(self.lig_type_encoder)
            self.lig_type_encoder[symbol] = virtual_index
            self.virtual_atom = self.lig_type_encoder[symbol]
            self.lig_type_decoder.append(symbol)
            self.data_transform = utils.AppendVirtualNodes(
                self.max_num_nodes, self.lig_type_encoder, symbol)

            # The visualization functions read the vocabularies from
            # dataset_info, so it has to carry the extended versions.
            self.dataset_info['atom_encoder'] = self.lig_type_encoder
            self.dataset_info['atom_decoder'] = self.lig_type_decoder

        self.atom_nf = len(self.lig_type_decoder)
        self.aa_nf = len(self.pocket_type_decoder)
        self.x_dims = 3

        # Optional EGNN settings default to None when absent.
        egnn_options = vars(egnn_params)
        if torch.cuda.is_available():
            dynamics_device = egnn_params.device
        else:
            dynamics_device = 'cpu'

        net_dynamics = EGNNDynamics(
            atom_nf=self.atom_nf,
            residue_nf=self.aa_nf,
            n_dims=self.x_dims,
            joint_nf=egnn_params.joint_nf,
            device=dynamics_device,
            hidden_nf=egnn_params.hidden_nf,
            act_fn=torch.nn.SiLU(),
            n_layers=egnn_params.n_layers,
            attention=egnn_params.attention,
            tanh=egnn_params.tanh,
            norm_constant=egnn_params.norm_constant,
            inv_sublayers=egnn_params.inv_sublayers,
            sin_embedding=egnn_params.sin_embedding,
            normalization_factor=egnn_params.normalization_factor,
            aggregation_method=egnn_params.aggregation_method,
            edge_cutoff_ligand=egnn_options.get('edge_cutoff_ligand'),
            edge_cutoff_pocket=egnn_options.get('edge_cutoff_pocket'),
            edge_cutoff_interaction=egnn_options.get(
                'edge_cutoff_interaction'),
            update_pocket_coords=(self.mode == 'joint'),
            reflection_equivariant=egnn_params.reflection_equivariant,
            edge_embedding_dim=egnn_options.get('edge_embedding_dim'),
        )

        ddpm_class = ddpm_models[self.mode]
        self.ddpm = ddpm_class(
            dynamics=net_dynamics,
            atom_nf=self.atom_nf,
            residue_nf=self.aa_nf,
            n_dims=self.x_dims,
            timesteps=diffusion_params.diffusion_steps,
            noise_schedule=diffusion_params.diffusion_noise_schedule,
            noise_precision=diffusion_params.diffusion_noise_precision,
            loss_type=diffusion_params.diffusion_loss_type,
            norm_values=diffusion_params.normalize_factors,
            size_histogram=node_histogram,
            virtual_node_idx=self.virtual_atom if virtual_nodes else None
        )

        # Optional Lennard-Jones auxiliary loss
        self.auxiliary_loss = auxiliary_loss
        self.lj_rm = self.dataset_info['lennard_jones_rm']
        if self.auxiliary_loss:
            self.clamp_lj = loss_params.clamp_lj
            self.auxiliary_weight_schedule = WeightSchedule(
                T=diffusion_params.diffusion_steps,
                max_weight=loss_params.max_weight,
                mode=loss_params.schedule)
182
183
    def configure_optimizers(self):
        """Return an AdamW optimizer (AMSGrad variant) over the DDPM weights."""
        optimizer_settings = {
            'lr': self.lr,
            'amsgrad': True,
            'weight_decay': 1e-12,
        }
        return torch.optim.AdamW(self.ddpm.parameters(), **optimizer_settings)
186
187
    def setup(self, stage: Optional[str] = None):
188
        if stage == 'fit':
189
            self.train_dataset = ProcessedLigandPocketDataset(
190
                Path(self.datadir, 'train.npz'), transform=self.data_transform)
191
            self.val_dataset = ProcessedLigandPocketDataset(
192
                Path(self.datadir, 'val.npz'), transform=self.data_transform)
193
        elif stage == 'test':
194
            self.test_dataset = ProcessedLigandPocketDataset(
195
                Path(self.datadir, 'test.npz'), transform=self.data_transform)
196
        else:
197
            raise NotImplementedError
198
199
    def train_dataloader(self):
        """Return the shuffled data loader over the training split."""
        train_set = self.train_dataset
        return DataLoader(
            train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=train_set.collate_fn,
            pin_memory=True,
        )
204
205
    def val_dataloader(self):
        """Return the data loader over the validation split (fixed order)."""
        val_set = self.val_dataset
        return DataLoader(
            val_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=val_set.collate_fn,
            pin_memory=True,
        )
210
211
    def test_dataloader(self):
212
        return DataLoader(self.test_dataset, self.batch_size, shuffle=False,
213
                          num_workers=self.num_workers,
214
                          collate_fn=self.test_dataset.collate_fn,
215
                          pin_memory=True)
216
217
    def get_ligand_and_pocket(self, data):
        """Split a collated batch into ligand and pocket dictionaries.

        Every tensor is moved to the module's device; coordinates and one-hot
        features are cast to ``FLOAT_TYPE``, sizes and masks to ``INT_TYPE``.

        Args:
            data: Batch dictionary with the ``lig_*`` / ``pocket_*`` entries
                (plus ``num_virtual_atoms`` when virtual nodes are enabled).

        Returns:
            Tuple ``(ligand, pocket)`` of dictionaries with the keys 'x',
            'one_hot', 'size' and 'mask'; the ligand additionally carries
            'num_virtual_atoms' when virtual nodes are enabled.
        """
        def fetch(key, dtype):
            return data[key].to(self.device, dtype)

        ligand = {}
        ligand['x'] = fetch('lig_coords', FLOAT_TYPE)
        ligand['one_hot'] = fetch('lig_one_hot', FLOAT_TYPE)
        ligand['size'] = fetch('num_lig_atoms', INT_TYPE)
        ligand['mask'] = fetch('lig_mask', INT_TYPE)
        if self.virtual_nodes:
            ligand['num_virtual_atoms'] = fetch('num_virtual_atoms', INT_TYPE)

        pocket = {}
        pocket['x'] = fetch('pocket_coords', FLOAT_TYPE)
        pocket['one_hot'] = fetch('pocket_one_hot', FLOAT_TYPE)
        pocket['size'] = fetch('num_pocket_nodes', INT_TYPE)
        pocket['mask'] = fetch('pocket_mask', INT_TYPE)

        return ligand, pocket
235
236
    def forward(self, data):
        """Compute the per-sample negative log-likelihood of a batch.

        During training with the 'l2' loss type the diffusion terms are
        normalized by the number of predicted dimensions; otherwise (VLB
        objective or evaluation) the weighted variational bound is used and
        corrected for the coordinate normalization and, without virtual
        nodes, for the size prior.

        Args:
            data: Collated batch as consumed by ``get_ligand_and_pocket``.

        Returns:
            Tuple ``(nll, info)``: the loss tensor (reduced with ``mean(0)``
            by the callers) and a dictionary of batch-averaged loss
            components for logging.
        """
        ligand, pocket = self.get_ligand_and_pocket(data)

        # Note: \mathcal{L} terms in the paper represent log-likelihoods
        # while our loss terms are negative(!) log-likelihoods.
        (delta_log_px, error_t_lig, error_t_pocket, SNR_weight,
         loss_0_x_ligand, loss_0_x_pocket, loss_0_h, neg_log_const_0,
         kl_prior, log_pN, t_int, xh_lig_hat, info) = \
            self.ddpm(ligand, pocket, return_info=True)

        # The simplified objective is only used while training with 'l2'.
        l2_training = self.loss_type == 'l2' and self.training

        if l2_training:
            if self.virtual_nodes:
                actual_ligand_size = \
                    ligand['size'] - ligand['num_virtual_atoms']
            else:
                actual_ligand_size = ligand['size']

            # normalize loss_t
            denom_lig = self.x_dims * actual_ligand_size + \
                self.ddpm.atom_nf * ligand['size']
            error_t_lig = error_t_lig / denom_lig
            denom_pocket = \
                (self.x_dims + self.ddpm.residue_nf) * pocket['size']
            error_t_pocket = error_t_pocket / denom_pocket
            loss_t = 0.5 * (error_t_lig + error_t_pocket)

            # normalize loss_0
            loss_0_x_ligand = \
                loss_0_x_ligand / (self.x_dims * actual_ligand_size)
            loss_0_x_pocket = \
                loss_0_x_pocket / (self.x_dims * pocket['size'])
            loss_0 = loss_0_x_ligand + loss_0_x_pocket + loss_0_h
        else:
            # VLB objective or evaluation step
            # Note: SNR_weight should be negative
            loss_t = \
                -self.T * 0.5 * SNR_weight * (error_t_lig + error_t_pocket)
            loss_0 = loss_0_x_ligand + loss_0_x_pocket + loss_0_h
            loss_0 = loss_0 + neg_log_const_0

        nll = loss_t + loss_0 + kl_prior

        if not l2_training:
            # Correct for normalization on x.
            nll = nll - delta_log_px

            # With virtual nodes the number of nodes is always the same, so
            # the size prior is only applied without them.
            if not self.virtual_nodes:
                # Transform conditional nll into joint nll:
                # loss = -log p(x,h|N) and
                # log p(x,h,N) = log p(x,h|N) + log p(N)
                # => loss_new = -log p(x,h,N) = loss - log p(N)
                nll = nll - log_pN

        # Add auxiliary loss term
        if self.auxiliary_loss and l2_training:
            x_lig_hat = xh_lig_hat[:, :self.x_dims]
            h_lig_hat = xh_lig_hat[:, self.x_dims:]
            weighted_lj_potential = \
                self.auxiliary_weight_schedule(t_int.long()) * \
                self.lj_potential(x_lig_hat, h_lig_hat, ligand['mask'])
            nll = nll + weighted_lj_potential
            info['weighted_lj'] = weighted_lj_potential.mean(0)

        # Batch averages of the individual terms for logging
        info['error_t_lig'] = error_t_lig.mean(0)
        info['error_t_pocket'] = error_t_pocket.mean(0)
        info['SNR_weight'] = SNR_weight.mean(0)
        info['loss_0'] = loss_0.mean(0)
        info['kl_prior'] = kl_prior.mean(0)
        info['delta_log_px'] = delta_log_px.mean(0)
        info['neg_log_const_0'] = neg_log_const_0.mean(0)
        info['log_pN'] = log_pN.mean(0)
        return nll, info
303
304
    def lj_potential(self, atom_x, atom_one_hot, batch_mask):
        """Sum a Lennard-Jones potential over all atom pairs of each sample.

        Args:
            atom_x: Atom coordinates, one row per atom.
            atom_one_hot: Per-atom type scores; the arg-max over dimension 1
                selects the atom type.
            batch_mask: Sample index of every atom.

        Returns:
            The summed potential for every sample index in ``batch_mask``.
        """
        # Ordered pairs of distinct atoms that belong to the same sample
        same_sample = batch_mask[:, None] == batch_mask[None, :]
        self_pairs = torch.diag(torch.diag(same_sample))
        src, dst = torch.where(same_sample ^ self_pairs)

        # Pair-wise distances
        r = torch.sum((atom_x[src] - atom_x[dst]) ** 2, dim=1).sqrt()

        # Optimal radii: unit conversion pm -> A, followed by the same
        # normalization that is applied to the coordinates
        lennard_jones_radii = \
            torch.tensor(self.lj_rm, device=r.device) / 100.0 \
            / self.ddpm.norm_values[0]
        atom_type_idx = atom_one_hot.argmax(1)
        rm = lennard_jones_radii[atom_type_idx[src], atom_type_idx[dst]]

        sigma = 2 ** (-1 / 6) * rm
        ratio = sigma / r
        out = 4 * (ratio ** 12 - ratio ** 6)

        if self.clamp_lj is not None:
            out = out.clamp(max=self.clamp_lj)

        # Potential per atom, then summed over the atoms of every sample
        per_atom = scatter_add(out, src, dim=0, dim_size=len(atom_x))
        return scatter_add(per_atom, batch_mask, dim=0)
332
333
    def log_metrics(self, metrics_dict, split, batch_size=None, **kwargs):
        """Log every entry of ``metrics_dict`` under the key ``<name>/<split>``.

        Additional keyword arguments are forwarded to ``self.log``.
        """
        for name in metrics_dict:
            key = f'{name}/{split}'
            self.log(key, metrics_dict[name], batch_size=batch_size, **kwargs)
336
337
    def training_step(self, data, *args):
        """Run one optimization step and log the loss components.

        Args:
            data: Collated training batch.
            *args: Extra positional arguments passed by the trainer (unused).

        Returns:
            The info dictionary including the mean loss under 'loss', or
            ``None`` to skip the batch after an out-of-memory error on a
            single device.

        Raises:
            NotImplementedError: If noise or rotation augmentation is
                requested; neither is implemented.
            RuntimeError: Any runtime error of the forward pass other than a
                single-device out-of-memory error.
        """
        if self.augment_noise > 0:
            raise NotImplementedError(
                'Noise augmentation is not implemented.')

        if self.augment_rotation:
            raise NotImplementedError(
                'Rotation augmentation is not implemented.')

        try:
            nll, info = self.forward(data)
        except RuntimeError as e:
            # Skipping a batch is not supported for multi-GPU training.
            if self.trainer.num_devices < 2 and 'out of memory' in str(e):
                print('WARNING: ran out of memory, skipping to the next batch')
                return None
            raise

        info['loss'] = nll.mean(0)
        self.log_metrics(info, 'train', batch_size=len(data['num_lig_atoms']))

        return info
364
365
    def _shared_eval(self, data, prefix, *args):
366
        nll, info = self.forward(data)
367
        loss = nll.mean(0)
368
369
        info['loss'] = loss
370
371
        self.log_metrics(info, prefix, batch_size=len(data['num_lig_atoms']),
372
                         sync_dist=True)
373
374
        return info
375
376
    def validation_step(self, data, *args):
        """Evaluate a validation batch; metrics are logged with split 'val'."""
        split = 'val'
        self._shared_eval(data, split, *args)
378
379
    def test_step(self, data, *args):
380
        self._shared_eval(data, 'test', *args)
381
382
    def validation_epoch_end(self, validation_step_outputs):
        """Periodically sample molecules for evaluation and visualization.

        Only the global-zero process does any work. Depending on the
        configured periods, the epoch triggers a sampling evaluation, a
        sample visualization and/or a chain visualization. The
        ``*_given_pocket`` variants are used unless the mode is 'joint'.
        """
        # Perform validation on single GPU
        if not self.trainer.is_global_zero:
            return

        suffix = '_given_pocket' if self.mode != 'joint' else ''
        finished_epochs = self.current_epoch + 1

        if finished_epochs % self.eval_epochs == 0:
            tic = time()

            analyze = getattr(self, 'sample_and_analyze' + suffix)
            sampling_results = analyze(
                self.eval_params.n_eval_samples, self.val_dataset,
                batch_size=self.eval_batch_size)
            self.log_metrics(sampling_results, 'val')

            print(f'Evaluation took {time() - tic:.2f} seconds')

        if finished_epochs % self.visualize_sample_epoch == 0:
            tic = time()
            save_samples = getattr(self, 'sample_and_save' + suffix)
            save_samples(self.eval_params.n_visualize_samples)
            print(f'Sample visualization took {time() - tic:.2f} seconds')

        if finished_epochs % self.visualize_chain_epoch == 0:
            tic = time()
            save_chain = getattr(self, 'sample_chain_and_save' + suffix)
            save_chain(self.eval_params.keep_frames)
            print(f'Chain visualization took {time() - tic:.2f} seconds')
411
412
    @torch.no_grad()
    def sample_and_analyze(self, n_samples, dataset=None, batch_size=None):
        """Sample ligand/pocket pairs from the joint model and score them.

        Args:
            n_samples: Total number of samples to draw.
            dataset: Unused; kept for a signature parallel to
                ``sample_and_analyze_given_pocket``.
            batch_size: Sampling batch size; defaults to the training batch
                size and is capped at ``n_samples``.

        Returns:
            The metrics dictionary produced by ``analyze_sample``.
        """
        print(f'Analyzing sampled molecules at epoch {self.current_epoch}...')

        if batch_size is None:
            batch_size = self.batch_size
        batch_size = min(batch_size, n_samples)
        n_batches = math.ceil(n_samples / batch_size)

        # each item in molecules is a tuple (position, atom_type_encoded)
        molecules = []
        atom_types = []
        aa_types = []
        for _batch_idx in range(n_batches):
            # The last batch may be smaller than the others.
            n_samples_batch = min(batch_size, n_samples - len(molecules))

            num_nodes_lig, num_nodes_pocket = \
                self.ddpm.size_distribution.sample(n_samples_batch)

            xh_lig, xh_pocket, lig_mask, _ = self.ddpm.sample(
                n_samples_batch, num_nodes_lig, num_nodes_pocket,
                device=self.device)

            coords = xh_lig[:, :self.x_dims].detach().cpu()
            atom_type = xh_lig[:, self.x_dims:].argmax(1).detach().cpu()
            lig_mask = lig_mask.cpu()

            coords_per_mol = utils.batch_to_list(coords, lig_mask)
            types_per_mol = utils.batch_to_list(atom_type, lig_mask)
            molecules.extend(zip(coords_per_mol, types_per_mol))

            atom_types.extend(atom_type.tolist())
            residue_type = xh_pocket[:, self.x_dims:].argmax(1)
            aa_types.extend(residue_type.detach().cpu().tolist())

        return self.analyze_sample(molecules, atom_types, aa_types)
448
449
    def analyze_sample(self, molecules, atom_types, aa_types, receptors=None):
        """Compute distribution, validity and property metrics for samples.

        Args:
            molecules: Iterable of tuples that are unpacked into
                ``build_molecule`` (positions and encoded atom types).
            atom_types: Flat list of sampled ligand atom types.
            aa_types: Flat list of sampled pocket node types.
            receptors: Optional receptor files; if given, a mean smina score
                of the connected molecules is added.

        Returns:
            Dictionary mapping metric names to values. A KL divergence is
            reported as -1 when the matching reference distribution is not
            available.
        """
        # Distribution of node types
        if self.ligand_type_distribution is None:
            kl_div_atom = -1
        else:
            kl_div_atom = \
                self.ligand_type_distribution.kl_divergence(atom_types)

        if self.pocket_type_distribution is None:
            kl_div_aa = -1
        else:
            kl_div_aa = self.pocket_type_distribution.kl_divergence(aa_types)

        # Convert into rdmols
        rdmols = []
        for graph in molecules:
            rdmols.append(build_molecule(*graph, self.dataset_info))

        # Other basic metrics
        basic_scores, molecule_lists = \
            self.ligand_metrics.evaluate_rdmols(rdmols)
        validity, connectivity, uniqueness, novelty = basic_scores
        _, connected_mols = molecule_lists

        qed, sa, logp, lipinski, diversity = \
            self.molecule_properties.evaluate_mean(connected_mols)

        out = dict()
        out['kl_div_atom_types'] = kl_div_atom
        out['kl_div_residue_types'] = kl_div_aa
        out['Validity'] = validity
        out['Connectivity'] = connectivity
        out['Uniqueness'] = uniqueness
        out['Novelty'] = novelty
        out['QED'] = qed
        out['SA'] = sa
        out['LogP'] = logp
        out['Lipinski'] = lipinski
        out['Diversity'] = diversity

        # Simple docking score
        if receptors is not None:
            docking_scores = smina_score(connected_mols, receptors)
            out['smina_score'] = np.mean(docking_scores)

        return out
486
487
    def get_full_path(self, receptor_name):
        """Map a receptor name ``<pdb>.<suffix>`` to its PDB file path.

        The file is expected at ``<datadir>/val/<PDB>-<suffix>.pdb`` with the
        PDB part upper-cased. Names that do not contain exactly one dot raise
        a ``ValueError``.
        """
        pdb_id, pocket_tag = receptor_name.split('.')
        file_name = f'{pdb_id.upper()}-{pocket_tag}.pdb'
        return Path(self.datadir) / 'val' / file_name
491
492
    @torch.no_grad()
    def sample_and_analyze_given_pocket(self, n_samples, dataset=None,
                                        batch_size=None):
        """Sample ligands for pockets taken from a dataset and score them.

        Pockets are taken from ``dataset`` in order, wrapping around when
        more samples than dataset entries are requested.

        Args:
            n_samples: Total number of ligands to sample.
            dataset: Dataset providing the pockets. Defaults to the
                validation dataset (receptor files are resolved by
                ``get_full_path``, which looks in the 'val' folder).
            batch_size: Sampling batch size; defaults to the training batch
                size and is capped at ``n_samples``.

        Returns:
            The metrics dictionary produced by ``analyze_sample``, including
            the docking score against the sampled pockets' receptors.
        """
        print(f'Analyzing sampled molecules given pockets at epoch '
              f'{self.current_epoch}...')

        # The declared default (None) cannot be indexed; fall back to the
        # validation split.
        if dataset is None:
            dataset = self.val_dataset

        batch_size = self.batch_size if batch_size is None else batch_size
        batch_size = min(batch_size, n_samples)
        n_entries = len(dataset)

        # each item in molecules is a tuple (position, atom_type_encoded)
        molecules = []
        atom_types = []
        aa_types = []
        receptors = []
        for i in range(math.ceil(n_samples / batch_size)):

            n_samples_batch = min(batch_size, n_samples - len(molecules))

            # Create a batch
            batch = dataset.collate_fn(
                [dataset[(i * batch_size + j) % n_entries]
                 for j in range(n_samples_batch)]
            )

            ligand, pocket = self.get_ligand_and_pocket(batch)
            receptors.extend(
                [self.get_full_path(x) for x in batch['receptors']])

            if self.virtual_nodes:
                # Virtual nodes pad every ligand to the maximum size.
                num_nodes_lig = self.max_num_nodes
            else:
                num_nodes_lig = self.ddpm.size_distribution.sample_conditional(
                    n1=None, n2=pocket['size'])

            xh_lig, xh_pocket, lig_mask, _ = self.ddpm.sample_given_pocket(
                pocket, num_nodes_lig)

            x = xh_lig[:, :self.x_dims].detach().cpu()
            atom_type = xh_lig[:, self.x_dims:].argmax(1).detach().cpu()
            lig_mask = lig_mask.cpu()

            if self.virtual_nodes:
                # Remove virtual nodes for analysis
                vnode_mask = (atom_type == self.virtual_atom)
                x = x[~vnode_mask, :]
                atom_type = atom_type[~vnode_mask]
                lig_mask = lig_mask[~vnode_mask]

            molecules.extend(list(
                zip(utils.batch_to_list(x, lig_mask),
                    utils.batch_to_list(atom_type, lig_mask))
            ))

            atom_types.extend(atom_type.tolist())
            aa_types.extend(
                xh_pocket[:, self.x_dims:].argmax(1).detach().cpu().tolist())

        return self.analyze_sample(molecules, atom_types, aa_types,
                                   receptors=receptors)
550
551
    def sample_and_save(self, n_samples):
        """Sample ligand/pocket pairs jointly and write them as xyz files.

        The files are written to ``<outdir>/epoch_<current epoch>/`` and
        rendered with ``visualize`` afterwards (without wandb logging).
        """
        size_distribution = self.ddpm.size_distribution
        num_nodes_lig, num_nodes_pocket = size_distribution.sample(n_samples)

        xh_lig, xh_pocket, lig_mask, pocket_mask = self.ddpm.sample(
            n_samples, num_nodes_lig, num_nodes_pocket, device=self.device)

        if self.pocket_representation == 'CA':
            # convert residues into atom representation for visualization
            x_pocket, one_hot_pocket = utils.residues_to_atoms(
                xh_pocket[:, :self.x_dims], self.lig_type_encoder)
        else:
            x_pocket = xh_pocket[:, :self.x_dims]
            one_hot_pocket = xh_pocket[:, self.x_dims:]

        # Stack ligand and pocket nodes into one point cloud
        x_lig = xh_lig[:, :self.x_dims]
        one_hot_lig = xh_lig[:, self.x_dims:]
        x = torch.cat((x_lig, x_pocket), dim=0)
        one_hot = torch.cat((one_hot_lig, one_hot_pocket), dim=0)
        batch_mask = torch.cat((lig_mask, pocket_mask))

        outdir = Path(self.outdir, f'epoch_{self.current_epoch}')
        save_xyz_file(str(outdir) + '/', one_hot, x, self.lig_type_decoder,
                      name='molecule', batch_mask=batch_mask)
        # visualize(str(outdir), dataset_info=self.dataset_info, wandb=wandb)
        visualize(str(outdir), dataset_info=self.dataset_info, wandb=None)
575
576
    def sample_and_save_given_pocket(self, n_samples):
        """Sample ligands for random validation pockets and write xyz files.

        ``n_samples`` pockets are drawn (with replacement) from the
        validation dataset. The files are written to
        ``<outdir>/epoch_<current epoch>/`` and rendered with ``visualize``
        afterwards (without wandb logging).
        """
        val_set = self.val_dataset
        chosen = torch.randint(len(val_set), size=(n_samples,))
        batch = val_set.collate_fn([val_set[i] for i in chosen])
        ligand, pocket = self.get_ligand_and_pocket(batch)

        if self.virtual_nodes:
            num_nodes_lig = self.max_num_nodes
        else:
            num_nodes_lig = self.ddpm.size_distribution.sample_conditional(
                n1=None, n2=pocket['size'])

        xh_lig, xh_pocket, lig_mask, pocket_mask = \
            self.ddpm.sample_given_pocket(pocket, num_nodes_lig)

        if self.pocket_representation == 'CA':
            # convert residues into atom representation for visualization
            x_pocket, one_hot_pocket = utils.residues_to_atoms(
                xh_pocket[:, :self.x_dims], self.lig_type_encoder)
        else:
            x_pocket = xh_pocket[:, :self.x_dims]
            one_hot_pocket = xh_pocket[:, self.x_dims:]

        # Stack ligand and pocket nodes into one point cloud
        x_lig = xh_lig[:, :self.x_dims]
        one_hot_lig = xh_lig[:, self.x_dims:]
        x = torch.cat((x_lig, x_pocket), dim=0)
        one_hot = torch.cat((one_hot_lig, one_hot_pocket), dim=0)
        batch_mask = torch.cat((lig_mask, pocket_mask))

        outdir = Path(self.outdir, f'epoch_{self.current_epoch}')
        save_xyz_file(str(outdir) + '/', one_hot, x, self.lig_type_decoder,
                      name='molecule', batch_mask=batch_mask)
        # visualize(str(outdir), dataset_info=self.dataset_info, wandb=wandb)
        visualize(str(outdir), dataset_info=self.dataset_info, wandb=None)
608
609
    def sample_chain_and_save(self, keep_frames):
        """Sample one joint trajectory and save its frames for visualization.

        Args:
            keep_frames: Passed to the sampler as ``return_frames``.

        The frames are written to ``<outdir>/epoch_<current epoch>/chain``
        and rendered with ``visualize_chain`` (logging to wandb).
        """
        n_samples = 1

        num_nodes_lig, num_nodes_pocket = \
            self.ddpm.size_distribution.sample(n_samples)

        chain_lig, chain_pocket, _, _ = self.ddpm.sample(
            n_samples, num_nodes_lig, num_nodes_pocket,
            return_frames=keep_frames, device=self.device)

        chain_lig = utils.reverse_tensor(chain_lig)
        chain_pocket = utils.reverse_tensor(chain_pocket)

        # Repeat last frame to see final sample better.
        last_lig_frames = chain_lig[-1:].repeat(10, 1, 1)
        chain_lig = torch.cat([chain_lig, last_lig_frames], dim=0)
        last_pocket_frames = chain_pocket[-1:].repeat(10, 1, 1)
        chain_pocket = torch.cat([chain_pocket, last_pocket_frames], dim=0)

        # Prepare entire chain: coordinates plus hard (arg-max) node types.
        x_lig = chain_lig[:, :, :self.x_dims]
        lig_types = chain_lig[:, :, self.x_dims:].argmax(dim=2)
        one_hot_lig = F.one_hot(
            lig_types, num_classes=len(self.lig_type_decoder))
        x_pocket = chain_pocket[:, :, :self.x_dims]
        pocket_types = chain_pocket[:, :, self.x_dims:].argmax(dim=2)
        one_hot_pocket = F.one_hot(
            pocket_types, num_classes=len(self.pocket_type_decoder))

        if self.pocket_representation == 'CA':
            # convert residues into atom representation for visualization
            x_pocket, one_hot_pocket = utils.residues_to_atoms(
                x_pocket, self.lig_type_encoder)

        x = torch.cat((x_lig, x_pocket), dim=1)
        one_hot = torch.cat((one_hot_lig, one_hot_pocket), dim=1)

        # flatten (treat frame (chain dimension) as batch for visualization)
        n_frames, n_nodes = x.size(0), x.size(1)
        x_flat = x.view(-1, x.size(-1))
        one_hot_flat = one_hot.view(-1, one_hot.size(-1))
        mask_flat = torch.arange(n_frames).repeat_interleave(n_nodes)

        outdir = Path(self.outdir, f'epoch_{self.current_epoch}', 'chain')
        save_xyz_file(str(outdir), one_hot_flat, x_flat, self.lig_type_decoder,
                      name='/chain', batch_mask=mask_flat)
        visualize_chain(str(outdir), self.dataset_info, wandb=wandb)
657
658
    def sample_chain_and_save_given_pocket(self, keep_frames):
        """Sample one trajectory for a random validation pocket and save it.

        Args:
            keep_frames: Passed to the sampler as ``return_frames``.

        The frames are written to ``<outdir>/epoch_<current epoch>/chain``
        and rendered with ``visualize_chain`` (logging to wandb).
        """
        # A single, randomly chosen pocket from the validation set
        batch = self.val_dataset.collate_fn([
            self.val_dataset[torch.randint(len(self.val_dataset), size=(1,))]
        ])
        ligand, pocket = self.get_ligand_and_pocket(batch)

        if self.virtual_nodes:
            num_nodes_lig = self.max_num_nodes
        else:
            num_nodes_lig = self.ddpm.size_distribution.sample_conditional(
                n1=None, n2=pocket['size'])

        chain_lig, chain_pocket, _, _ = self.ddpm.sample_given_pocket(
            pocket, num_nodes_lig, return_frames=keep_frames)

        chain_lig = utils.reverse_tensor(chain_lig)
        chain_pocket = utils.reverse_tensor(chain_pocket)

        # Repeat last frame to see final sample better.
        chain_lig = torch.cat([chain_lig, chain_lig[-1:].repeat(10, 1, 1)],
                              dim=0)
        chain_pocket = torch.cat(
            [chain_pocket, chain_pocket[-1:].repeat(10, 1, 1)], dim=0)

        # Prepare entire chain: coordinates plus hard (arg-max) node types.
        # The split point is self.x_dims for both ligand and pocket.
        x_lig = chain_lig[:, :, :self.x_dims]
        one_hot_lig = chain_lig[:, :, self.x_dims:]
        one_hot_lig = F.one_hot(
            torch.argmax(one_hot_lig, dim=2),
            num_classes=len(self.lig_type_decoder))
        x_pocket = chain_pocket[:, :, :self.x_dims]
        one_hot_pocket = chain_pocket[:, :, self.x_dims:]
        one_hot_pocket = F.one_hot(
            torch.argmax(one_hot_pocket, dim=2),
            num_classes=len(self.pocket_type_decoder))

        if self.pocket_representation == 'CA':
            # convert residues into atom representation for visualization
            x_pocket, one_hot_pocket = utils.residues_to_atoms(
                x_pocket, self.lig_type_encoder)

        x = torch.cat((x_lig, x_pocket), dim=1)
        one_hot = torch.cat((one_hot_lig, one_hot_pocket), dim=1)

        # flatten (treat frame (chain dimension) as batch for visualization)
        x_flat = x.view(-1, x.size(-1))
        one_hot_flat = one_hot.view(-1, one_hot.size(-1))
        mask_flat = torch.arange(x.size(0)).repeat_interleave(x.size(1))

        outdir = Path(self.outdir, f'epoch_{self.current_epoch}', 'chain')
        save_xyz_file(str(outdir), one_hot_flat, x_flat, self.lig_type_decoder,
                      name='/chain', batch_mask=mask_flat)
        visualize_chain(str(outdir), self.dataset_info, wandb=wandb)
713
714
    def prepare_pocket(self, biopython_residues, repeats=1):
        """Turn a list of Biopython residues into a (repeated) pocket dict.

        Args:
            biopython_residues: iterable of residue objects; each must
                provide ``res['CA']`` and ``get_resname()`` in 'CA' mode, or
                ``get_atoms()`` otherwise.
            repeats: number of identical copies of the pocket to stack.

        Returns:
            dict with keys ``'x'`` (coordinates), ``'one_hot'`` (type
            features), ``'size'`` (nodes per copy) and ``'mask'`` (copy index
            of every node).
        """
        device = self.device
        encoder = self.pocket_type_encoder

        if self.pocket_representation == 'CA':
            # One node per residue, placed at its C-alpha atom.
            ca_positions = np.array(
                [residue['CA'].get_coord() for residue in biopython_residues])
            pocket_coord = torch.tensor(
                ca_positions, device=device, dtype=FLOAT_TYPE)

            residue_type_ids = [
                encoder[three_to_one(residue.get_resname())]
                for residue in biopython_residues
            ]
            pocket_types = torch.tensor(residue_type_ids, device=device)
        else:
            # One node per atom. An atom is kept when its element is known to
            # the encoder or when it is not a hydrogen; a non-hydrogen element
            # missing from the encoder therefore reaches the lookup below and
            # raises KeyError.
            kept_atoms = []
            for residue in biopython_residues:
                for atom in residue.get_atoms():
                    if (atom.element.capitalize() in encoder
                            or atom.element != 'H'):
                        kept_atoms.append(atom)

            atom_positions = np.array(
                [atom.get_coord() for atom in kept_atoms])
            pocket_coord = torch.tensor(
                atom_positions, device=device, dtype=FLOAT_TYPE)

            atom_type_ids = [
                encoder[atom.element.capitalize()] for atom in kept_atoms
            ]
            pocket_types = torch.tensor(atom_type_ids, device=device)

        pocket_one_hot = F.one_hot(pocket_types, num_classes=len(encoder))

        n_nodes = len(pocket_coord)

        # Every copy has the same number of nodes ...
        pocket_size = torch.tensor(
            [n_nodes] * repeats, device=device, dtype=INT_TYPE)
        # ... and node i of copy k carries batch index k.
        copy_indices = torch.arange(repeats, device=device, dtype=INT_TYPE)
        pocket_mask = torch.repeat_interleave(copy_indices, n_nodes)

        return {
            'x': pocket_coord.repeat(repeats, 1),
            'one_hot': pocket_one_hot.repeat(repeats, 1),
            'size': pocket_size,
            'mask': pocket_mask,
        }
753
754
    def generate_ligands(self, pdb_file, n_samples, pocket_ids=None,
                         ref_ligand=None, num_nodes_lig=None, sanitize=False,
                         largest_frag=False, relax_iter=0, timesteps=None,
                         n_nodes_bias=0, n_nodes_min=0, **kwargs):
        """
        Generate ligands given a pocket
        Args:
            pdb_file: PDB filename
            n_samples: number of samples
            pocket_ids: list of pocket residues in <chain>:<resi> format
            ref_ligand: alternative way of defining the pocket based on a
                reference ligand given in <chain>:<resi> format if the ligand is
                contained in the PDB file, or path to an SDF file that
                contains the ligand
            num_nodes_lig: number of ligand nodes for each sample (list or
                tensor of integers), sampled randomly if 'None'
            sanitize: whether to sanitize molecules or not
            largest_frag: only return the largest fragment
            relax_iter: number of force field optimization steps
            timesteps: number of denoising steps, use training value if None
            n_nodes_bias: added to the sampled (or provided) number of nodes
            n_nodes_min: lower bound on the number of sampled nodes
            kwargs: additional inpainting parameters
        Returns:
            list of molecules (samples for which post-processing returned
            None are dropped, so the list can be shorter than n_samples)
        Raises:
            AssertionError: unless exactly one of pocket_ids / ref_ligand is
                given
            NotImplementedError: if the diffusion model is neither exactly an
                EnVariationalDiffusion nor exactly a ConditionalDDPM
        """

        assert (pocket_ids is None) ^ (ref_ligand is None), \
            "Specify exactly one of 'pocket_ids' and 'ref_ligand'"

        self.ddpm.eval()

        # Load PDB (first model only)
        pdb_struct = PDBParser(QUIET=True).get_structure('', pdb_file)[0]
        if pocket_ids is not None:
            # define pocket with list of residues
            residues = []
            for pocket_id in pocket_ids:
                id_parts = pocket_id.split(':')
                chain_id, res_num = id_parts[0], int(id_parts[1])
                residues.append(pdb_struct[chain_id][(' ', res_num, ' ')])
        else:
            # define pocket with reference ligand
            residues = utils.get_pocket_from_ligand(pdb_struct, ref_ligand)

        pocket = self.prepare_pocket(residues, repeats=n_samples)

        # Pocket's center of mass
        pocket_com_before = scatter_mean(pocket['x'], pocket['mask'], dim=0)

        # Create dummy ligands
        if num_nodes_lig is None:
            num_nodes_lig = self.ddpm.size_distribution.sample_conditional(
                n1=None, n2=pocket['size'])
        elif not torch.is_tensor(num_nodes_lig):
            # The documented "list of integers" cannot be used in the tensor
            # arithmetic below, so convert it once up front.
            num_nodes_lig = torch.as_tensor(
                num_nodes_lig, device=self.device, dtype=INT_TYPE)

        # Add bias
        num_nodes_lig = num_nodes_lig + n_nodes_bias

        # Apply minimum ligand size
        num_nodes_lig = torch.clamp(num_nodes_lig, min=n_nodes_min)

        # NOTE: the dispatch below compares exact types on purpose; an
        # isinstance() check would also match subclasses and could route a
        # model to the wrong branch.
        ddpm_type = type(self.ddpm)

        # Use inpainting
        if ddpm_type is EnVariationalDiffusion:
            lig_mask = utils.num_nodes_to_batch_mask(
                len(num_nodes_lig), num_nodes_lig, self.device)

            # Placeholder ligand: all-zero coordinates and features.
            ligand = {
                'x': torch.zeros((len(lig_mask), self.x_dims),
                                 device=self.device, dtype=FLOAT_TYPE),
                'one_hot': torch.zeros((len(lig_mask), self.atom_nf),
                                       device=self.device, dtype=FLOAT_TYPE),
                'size': num_nodes_lig,
                'mask': lig_mask
            }

            # Fix all pocket nodes but sample all ligand nodes
            lig_mask_fixed = torch.zeros(len(lig_mask), device=self.device)
            pocket_mask_fixed = torch.ones(len(pocket['mask']),
                                           device=self.device)

            xh_lig, xh_pocket, lig_mask, pocket_mask = self.ddpm.inpaint(
                ligand, pocket, lig_mask_fixed, pocket_mask_fixed,
                timesteps=timesteps, **kwargs)

        # Use conditional generation
        elif ddpm_type is ConditionalDDPM:
            xh_lig, xh_pocket, lig_mask, pocket_mask = \
                self.ddpm.sample_given_pocket(pocket, num_nodes_lig,
                                              timesteps=timesteps)

        else:
            raise NotImplementedError

        # Move generated molecule back to the original pocket position
        pocket_com_after = scatter_mean(
            xh_pocket[:, :self.x_dims], pocket_mask, dim=0)
        com_shift = pocket_com_before - pocket_com_after

        xh_pocket[:, :self.x_dims] += com_shift[pocket_mask]
        xh_lig[:, :self.x_dims] += com_shift[lig_mask]

        # Build mol objects
        x = xh_lig[:, :self.x_dims].detach().cpu()
        atom_type = xh_lig[:, self.x_dims:].argmax(1).detach().cpu()
        lig_mask = lig_mask.cpu()

        molecules = []
        for mol_pc in zip(utils.batch_to_list(x, lig_mask),
                          utils.batch_to_list(atom_type, lig_mask)):

            mol = build_molecule(*mol_pc, self.dataset_info, add_coords=True)
            mol = process_molecule(mol,
                                   add_hydrogens=False,
                                   sanitize=sanitize,
                                   relax_iter=relax_iter,
                                   largest_frag=largest_frag)
            # process_molecule may return None; such samples are skipped
            if mol is not None:
                molecules.append(mol)

        return molecules
873
874
    def configure_gradient_clipping(self, optimizer, optimizer_idx,
                                    gradient_clip_val, gradient_clip_algorithm):
        """Clip gradients with a threshold adapted to recent gradient norms.

        Does nothing when ``self.clip_grad`` is falsy. Otherwise the threshold
        is derived from ``self.gradnorm_queue`` and the passed-in
        ``gradient_clip_val`` / ``gradient_clip_algorithm`` are not used.
        """
        if not self.clip_grad:
            return

        # Allow gradient norm to be 150% + 2 * stdev of the recent history.
        history = self.gradnorm_queue
        max_grad_norm = 1.5 * history.mean() + 2 * history.std()

        # Norm of the current gradients, measured before clipping
        all_params = []
        for group in optimizer.param_groups:
            all_params.extend(group['params'])
        grad_norm = utils.get_grad_norm(all_params)

        # Lightning will handle the gradient clipping
        self.clip_gradients(optimizer, gradient_clip_val=max_grad_norm,
                            gradient_clip_algorithm='norm')

        # Record what was effectively applied: the threshold if the gradient
        # was clipped, the measured norm otherwise.
        was_clipped = float(grad_norm) > max_grad_norm
        if was_clipped:
            self.gradnorm_queue.add(float(max_grad_norm))
            print(f'Clipped gradient with value {grad_norm:.1f} '
                  f'while allowed {max_grad_norm:.1f}')
        else:
            self.gradnorm_queue.add(float(grad_norm))
900
901
902
class WeightSchedule:
    """Lookup table of per-timestep weights for t = 0, ..., T.

    Args:
        T: last timestep index; the table holds ``T + 1`` entries.
        max_weight: weight at ``t = 0`` ('linear') or at every step
            ('constant').
        mode: 'linear' decreases evenly from ``max_weight`` at ``t = 0`` to
            0 at ``t = T``; 'constant' uses ``max_weight`` everywhere.

    Raises:
        NotImplementedError: for any other ``mode``.
    """

    def __init__(self, T, max_weight, mode='linear'):
        if mode == 'linear':
            self.weights = torch.linspace(max_weight, 0, T + 1)
        elif mode == 'constant':
            self.weights = max_weight * torch.ones(T + 1)
        else:
            raise NotImplementedError(f'{mode} weight schedule is not '
                                      f'available.')

    def __call__(self, t_array):
        """ all values in t_array are assumed to be integers in [0, T] """
        # Move the (small) table to the index tensor's device *before*
        # indexing: indexing a CPU tensor with indices that live on another
        # device is rejected by current PyTorch versions.
        return self.weights.to(t_array.device)[t_array]