Diff of /src/util/utils.py [000000] .. [7d53f6]

import os
import time
import math
import datetime
import warnings
import itertools
from copy import deepcopy
from functools import partial
from collections import Counter
from multiprocessing import Pool
from statistics import mean

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from scipy.spatial.distance import cosine as cos_distance

import torch
import wandb

from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import (
    AllChem,
    Draw,
    Descriptors,
    Lipinski,
    Crippen,
    rdMolDescriptors,
    FilterCatalog,
)
from rdkit.Chem.Scaffolds import MurckoScaffold

# Disable RDKit warnings
RDLogger.DisableLog("rdApp.*")


class Metrics(object):
    """
    Collection of static methods to compute various metrics for molecules.
    """

    @staticmethod
    def valid(x):
        """
        Checks whether the molecule is valid.

        Args:
            x: RDKit molecule object.

        Returns:
            bool: True if molecule is valid and has a non-empty SMILES representation.
        """
        return x is not None and Chem.MolToSmiles(x) != ''
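
    # Illustrative example (assumes RDKit; a sanity check, not part of the original API):
    #   >>> Metrics.valid(Chem.MolFromSmiles("CCO"))
    #   True
    #   >>> Metrics.valid(None)
    #   False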

    @staticmethod
    def tanimoto_sim_1v2(data1, data2):
        """
        Computes the average Tanimoto similarity for paired fingerprints.

        Args:
            data1: Fingerprint data for first set.
            data2: Fingerprint data for second set.

        Returns:
            float: The average Tanimoto similarity between corresponding fingerprints.
        """
        # Pair fingerprints up to the size of the smaller set
        min_len = data1.size if data1.size < data2.size else data2.size
        sims = []
        for i in range(min_len):
            sim = DataStructs.FingerprintSimilarity(data1[i], data2[i])
            sims.append(sim)
        # Average the pairwise similarities
        mean_sim = mean(sims)
        return mean_sim
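
    # Usage sketch with hypothetical fingerprint arrays (assumes RDKit and numpy):
    #   >>> fps1 = np.array([AllChem.GetMorganFingerprintAsBitVect(
    #   ...     Chem.MolFromSmiles(s), 2, nBits=1024) for s in ("CCO", "CCN")])
    #   >>> fps2 = np.array([AllChem.GetMorganFingerprintAsBitVect(
    #   ...     Chem.MolFromSmiles(s), 2, nBits=1024) for s in ("CCO", "CCC")])
    #   >>> Metrics.tanimoto_sim_1v2(fps1, fps2)  # average of the two pairwise similarities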

    @staticmethod
    def mol_length(x):
        """
        Computes the length of the largest fragment (by character count) in a SMILES string.

        Args:
            x (str): SMILES string.

        Returns:
            int: Number of alphabetic characters in the longest fragment of the SMILES.
        """
        if x is not None:
            # Split at dots (.) and take the fragment with maximum length, then count alphabetic characters.
            return len([char for char in max(x.split(sep="."), key=len).upper() if char.isalpha()])
        else:
            return 0
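
    # Illustrative example: "CCO.C" splits into ["CCO", "C"]; the longest
    # fragment "CCO" contains three alphabetic characters.
    #   >>> Metrics.mol_length("CCO.C")
    #   3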

    @staticmethod
    def max_component(data, max_len):
        """
        Returns the average normalized length of molecules in the dataset.

        Each molecule's length is computed and divided by max_len, then averaged.

        Args:
            data (iterable): Collection of SMILES strings.
            max_len (int): Maximum possible length for normalization.

        Returns:
            float: Normalized average length.
        """
        lengths = np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)
        return (lengths / max_len).mean()

    @staticmethod
    def mean_atom_type(data):
        """
        Computes the average number of unique atom types in the provided node data.

        Args:
            data (iterable): Iterable containing node data with unique atom types.

        Returns:
            float: The average count of unique atom types, subtracting one.
        """
        atom_types_used = []
        for i in data:
            # Assuming each element i has a .unique() method that returns unique atom types.
            atom_types_used.append(len(i.unique().tolist()))
        av_type = np.mean(atom_types_used) - 1
        return av_type
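
    # Illustrative example with a node-annotation tensor of atom-type indices
    # (the "- 1" is assumed to discount a padding/empty type):
    #   >>> nodes = torch.tensor([[0, 1, 1, 2], [0, 3, 3, 3]])
    #   >>> Metrics.mean_atom_type(nodes)  # mean(3, 2) - 1 = 1.5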


def mols2grid_image(mols, path):
    """
    Saves grid images for a list of molecules.

    For each molecule in the list, computes 2D coordinates and saves an image file.

    Args:
        mols (list): List of RDKit molecule objects.
        path (str): Directory where images will be saved.
    """
    # Replace None molecules with an empty molecule
    mols = [e if e is not None else Chem.RWMol() for e in mols]

    for i in range(len(mols)):
        if Metrics.valid(mols[i]):
            AllChem.Compute2DCoords(mols[i])
            file_path = os.path.join(path, "{}.png".format(i + 1))
            Draw.MolToFile(mols[i], file_path, size=(1200, 1200))
            # wandb.save(file_path)  # Optionally save to Weights & Biases
        else:
            continue


def save_smiles_matrices(mols, edges_hard, nodes_hard, path, data_source=None):
    """
    Saves the edge and node matrices along with SMILES strings to text files.

    Each file contains the edge matrix, node matrix, and SMILES representation for a molecule.

    Args:
        mols (list): List of RDKit molecule objects.
        edges_hard (torch.Tensor): Tensor of edge features.
        nodes_hard (torch.Tensor): Tensor of node features.
        path (str): Directory where files will be saved.
        data_source: Optional data source information (not used in function).
    """
    mols = [e if e is not None else Chem.RWMol() for e in mols]

    for i in range(len(mols)):
        if Metrics.valid(mols[i]):
            save_path = os.path.join(path, "{}.txt".format(i + 1))
            with open(save_path, "a") as f:
                np.savetxt(f, edges_hard[i].cpu().numpy(), header="edge matrix:\n", fmt='%1.2f')
                f.write("\n")
                np.savetxt(f, nodes_hard[i].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:", fmt='%1.2f')
                f.write("\n")
            # Append the SMILES representation to the file
            with open(save_path, "a") as f:
                print(Chem.MolToSmiles(mols[i]), file=f)
            # wandb.save(save_path)  # Optionally save to Weights & Biases
        else:
            continue


def dense_to_sparse_with_attr(adj):
    """
    Converts a dense adjacency matrix to a sparse representation.

    Args:
        adj (torch.Tensor): Adjacency matrix tensor (2D or 3D) with square last two dimensions.

    Returns:
        tuple: A tuple containing indices and corresponding edge attributes.
    """
    assert adj.dim() >= 2 and adj.dim() <= 3
    assert adj.size(-1) == adj.size(-2)

    index = adj.nonzero(as_tuple=True)
    edge_attr = adj[index]

    if len(index) == 3:
        batch = index[0] * adj.size(-1)
        index = (batch + index[1], batch + index[2])
    return index, edge_attr
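
# Minimal sketch of the conversion (hypothetical input):
#   >>> adj = torch.tensor([[[0., 2.], [0., 0.]]])  # one batched 2-node graph
#   >>> index, edge_attr = dense_to_sparse_with_attr(adj)
#   >>> index  # node indices flattened by a per-graph batch offset
#   (tensor([0]), tensor([1]))
#   >>> edge_attr
#   tensor([2.])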


def mol_sample(sample_directory, edges, nodes, idx, i, matrices2mol, dataset_name):
    """
    Samples molecules from edge and node predictions, then saves grid images and text files.

    Args:
        sample_directory (str): Directory to save the samples.
        edges (torch.Tensor): Edge predictions tensor.
        nodes (torch.Tensor): Node predictions tensor.
        idx (int): Current index for naming the sample.
        i (int): Epoch/iteration index.
        matrices2mol (callable): Function to convert matrices to RDKit molecule.
        dataset_name (str): Name of the dataset for file naming.
    """
    sample_path = os.path.join(sample_directory, "{}_{}-epoch_iteration".format(idx + 1, i + 1))
    # Get the index of the maximum predicted feature along the last dimension
    g_edges_hat_sample = torch.max(edges, -1)[1]
    g_nodes_hat_sample = torch.max(nodes, -1)[1]
    # Convert matrices to molecule objects
    mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
                        strict=True, file_name=dataset_name)
           for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]

    if not os.path.exists(sample_path):
        os.makedirs(sample_path)

    mols2grid_image(mol, sample_path)
    save_smiles_matrices(mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), sample_path)

    # Remove the directory if no files were saved
    if len(os.listdir(sample_path)) == 0:
        os.rmdir(sample_path)

    print("Valid molecules, matrices, and SMILES are saved.")


def logging(log_path, start_time, i, idx, loss, save_path, drug_smiles, edge, node,
            matrices2mol, dataset_name, real_adj, real_annot, drug_vecs):
    """
    Logs training statistics and evaluation metrics.

    The function generates molecules from predictions, computes various metrics such as
    validity, uniqueness, novelty, and similarity scores, and logs them using wandb and a file.

    Args:
        log_path (str): Path to save the log file.
        start_time (float): Start time to compute elapsed time.
        i (int): Current iteration index.
        idx (int): Current epoch index.
        loss (dict): Dictionary to update with loss and metric values.
        save_path (str): Directory path to save sample outputs.
        drug_smiles (list): List of reference drug SMILES.
        edge (torch.Tensor): Edge prediction tensor.
        node (torch.Tensor): Node prediction tensor.
        matrices2mol (callable): Function to convert matrices to molecules.
        dataset_name (str): Dataset name.
        real_adj (torch.Tensor): Ground truth adjacency matrix tensor.
        real_annot (torch.Tensor): Ground truth annotation tensor.
        drug_vecs (list): List of drug vectors for similarity calculation.
    """
    g_edges_hat_sample = torch.max(edge, -1)[1]
    g_nodes_hat_sample = torch.max(node, -1)[1]

    a_tensor_sample = torch.max(real_adj, -1)[1].float()
    x_tensor_sample = torch.max(real_annot, -1)[1].float()

    # Generate molecules from predictions and real data
    mols = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
                         strict=True, file_name=dataset_name)
            for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
    real_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
                             strict=True, file_name=dataset_name)
                for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]

    # Compute average number of atom types
    atom_types_average = Metrics.mean_atom_type(g_nodes_hat_sample)
    real_smiles = [Chem.MolToSmiles(x) for x in real_mol if x is not None]
    gen_smiles = []
    uniq_smiles = []
    for line in mols:
        if line is not None:
            gen_smiles.append(Chem.MolToSmiles(line))
            uniq_smiles.append(Chem.MolToSmiles(line))
        else:
            gen_smiles.append(None)

    # Process SMILES to take the longest fragment if multiple are present
    gen_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in gen_smiles]
    uniq_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in uniq_smiles]

    # Save the generated SMILES to a text file
    sample_save_dir = os.path.join(save_path, "samples.txt")
    with open(sample_save_dir, "a") as f:
        for s in gen_smiles_saves:
            if s is not None:
                f.write(s + "\n")

    k = len(set(uniq_smiles_saves) - {None})
    et = time.time() - start_time
    et = str(datetime.timedelta(seconds=et))[:-7]
    log_str = "Elapsed [{}], Epoch/Iteration [{}/{}]".format(et, idx, i + 1)

    # Generate molecular fingerprints for similarity computations
    gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in mols if x is not None]
    chembl_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_mol if x is not None]

    # Compute evaluation metrics: validity, uniqueness, novelty, similarity scores,
    # and average maximum molecule length.
    valid = fraction_valid(gen_smiles_saves)
    unique = fraction_unique(uniq_smiles_saves, k)
    novel_starting_mol = novelty(gen_smiles_saves, real_smiles)
    novel_akt = novelty(gen_smiles_saves, drug_smiles)
    if len(uniq_smiles_saves) == 0:
        snn_chembl = 0
        snn_akt = 0
        maxlen = 0
    else:
        snn_chembl = average_agg_tanimoto(np.array(chembl_vecs), np.array(gen_vecs))
        snn_akt = average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs))
        maxlen = Metrics.max_component(uniq_smiles_saves, 45)

    # Update loss dictionary with computed metrics
    loss.update({
        'Validity': valid,
        'Uniqueness': unique,
        'Novelty': novel_starting_mol,
        'Novelty_akt': novel_akt,
        'SNN_chembl': snn_chembl,
        'SNN_akt': snn_akt,
        'MaxLen': maxlen,
        'Atom_types': atom_types_average
    })

    # Log metrics using wandb
    wandb.log({
        "Validity": valid,
        "Uniqueness": unique,
        "Novelty": novel_starting_mol,
        "Novelty_akt": novel_akt,
        "SNN_chembl": snn_chembl,
        "SNN_akt": snn_akt,
        "MaxLen": maxlen,
        "Atom_types": atom_types_average
    })

    # Append each metric to the log string and write to the log file
    for tag, value in loss.items():
        log_str += ", {}: {:.4f}".format(tag, value)
    with open(log_path, "a") as f:
        f.write(log_str + "\n")
    print(log_str)
    print("\n")


def plot_grad_flow(named_parameters, model, itera, epoch, grad_flow_directory):
    """
    Plots the gradients flowing through different layers during training.

    This is useful to check for possible gradient vanishing or exploding problems.

    Args:
        named_parameters (iterable): Iterable of (name, parameter) tuples from the model.
        model (str): Name of the model (used for saving the plot).
        itera (int): Iteration index.
        epoch (int): Current epoch.
        grad_flow_directory (str): Directory to save the gradient flow plot.
    """
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if p.requires_grad and p.grad is not None and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().cpu())
            max_grads.append(p.grad.abs().max().cpu())
    # Start a fresh figure so repeated calls do not overlay bars
    plt.figure()
    # Plot maximum gradients and average gradients for each layer
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=1)  # Zoom in on lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("Average Gradient")
    plt.title("Gradient Flow")
    plt.grid(True)
    plt.legend([
        Line2D([0], [0], color="c", lw=4),
        Line2D([0], [0], color="b", lw=4),
        Line2D([0], [0], color="k", lw=4)
    ], ['max-gradient', 'mean-gradient', 'zero-gradient'])
    # Save the plot to the specified directory, then close the figure to free memory
    plt.savefig(os.path.join(grad_flow_directory, "weights_" + model + "_" + str(itera) + "_" + str(epoch) + ".png"), dpi=500, bbox_inches='tight')
    plt.close()


def get_mol(smiles_or_mol):
    """
    Loads a SMILES string or molecule into an RDKit molecule object.

    Args:
        smiles_or_mol (str or RDKit Mol): SMILES string or RDKit molecule.

    Returns:
        RDKit Mol or None: Sanitized molecule object, or None if invalid.
    """
    if isinstance(smiles_or_mol, str):
        if len(smiles_or_mol) == 0:
            return None
        mol = Chem.MolFromSmiles(smiles_or_mol)
        if mol is None:
            return None
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            return None
        return mol
    return smiles_or_mol


def mapper(n_jobs):
    """
    Returns a mapping function for parallel or serial processing.

    If n_jobs == 1, returns a serial map that collects results into a list.
    If n_jobs > 1, returns a function that uses a multiprocessing pool.

    Args:
        n_jobs (int or pool object): Number of jobs or a Pool instance.

    Returns:
        callable: A function that acts like map.
    """
    if n_jobs == 1:
        def _mapper(*args, **kwargs):
            return list(map(*args, **kwargs))
        return _mapper
    if isinstance(n_jobs, int):
        pool = Pool(n_jobs)
        def _mapper(*args, **kwargs):
            try:
                result = pool.map(*args, **kwargs)
            finally:
                pool.terminate()
            return result
        return _mapper
    return n_jobs.map
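
# Usage sketch (illustrative):
#   >>> mapper(1)(len, ["CC", "CCO"])  # serial map, collected into a list
#   [2, 3]
#   >>> mapper(4)(canonic_smiles, smiles_list)  # hypothetical smiles_list, 4 worker processes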


def remove_invalid(gen, canonize=True, n_jobs=1):
    """
    Removes invalid molecules from the provided dataset.

    Optionally canonizes the SMILES strings.

    Args:
        gen (list): List of SMILES strings.
        canonize (bool): Whether to convert to canonical SMILES.
        n_jobs (int): Number of parallel jobs.

    Returns:
        list: Filtered list of valid molecules.
    """
    if not canonize:
        mols = mapper(n_jobs)(get_mol, gen)
        return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
    return [x for x in mapper(n_jobs)(canonic_smiles, gen) if x is not None]


def fraction_valid(gen, n_jobs=1):
    """
    Computes the fraction of valid molecules in the dataset.

    Args:
        gen (list): List of SMILES strings.
        n_jobs (int): Number of parallel jobs.

    Returns:
        float: Fraction of molecules that are valid.
    """
    gen = mapper(n_jobs)(get_mol, gen)
    return 1 - gen.count(None) / len(gen)
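
# Illustrative example: two of the three SMILES parse successfully.
#   >>> fraction_valid(["CCO", "not_a_smiles", "c1ccccc1"])  # -> 2/3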


def canonic_smiles(smiles_or_mol):
    """
    Converts a SMILES string or molecule to its canonical SMILES.

    Args:
        smiles_or_mol (str or RDKit Mol): Input molecule.

    Returns:
        str or None: Canonical SMILES string or None if invalid.
    """
    mol = get_mol(smiles_or_mol)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol)
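
# Illustrative example: different spellings of ethanol collapse to one form.
#   >>> canonic_smiles("OCC")
#   'CCO'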


def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
    """
    Computes the fraction of unique molecules.

    Optionally computes unique@k, where only the first k molecules are considered.

    Args:
        gen (list): List of SMILES strings.
        k (int): Optional cutoff for unique@k computation.
        n_jobs (int): Number of parallel jobs.
        check_validity (bool): Whether to check for validity of molecules.

    Returns:
        float: Fraction of unique molecules.
    """
    if k is not None:
        if len(gen) < k:
            warnings.warn("Can't compute unique@{}.".format(k) +
                          " gen contains only {} molecules".format(len(gen)))
        gen = gen[:k]
    canonic = list(mapper(n_jobs)(canonic_smiles, gen))
    if check_validity:
        # Drop molecules that failed canonicalization
        canonic = [i for i in canonic if i is not None]
    set_canonic = set(canonic)
    return 0 if len(canonic) == 0 else len(set_canonic) / len(canonic)
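
# Illustrative example: "CCO" and "OCC" canonicalize to the same molecule,
# so only two of the three valid SMILES are unique.
#   >>> fraction_unique(["CCO", "OCC", "c1ccccc1"])  # -> 2/3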


def novelty(gen, train, n_jobs=1):
    """
    Computes the novelty score of generated molecules.

    Novelty is defined as the fraction of generated molecules that do not appear in the training set.

    Args:
        gen (list): List of generated SMILES strings.
        train (list): List of training SMILES strings.
        n_jobs (int): Number of parallel jobs.

    Returns:
        float: Novelty score.
    """
    gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
    gen_smiles_set = set(gen_smiles) - {None}
    train_set = set(train)
    return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
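
# Illustrative example: one of the two unique generated molecules appears
# in the training set, so novelty is 0.5.
#   >>> novelty(["CCO", "CCN"], ["CCO"])
#   0.5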


def internal_diversity(gen):
    """
    Computes the internal diversity of a set of molecules.

    Internal diversity is defined as one minus the average Tanimoto similarity between all pairs.

    Args:
        gen: Array-like fingerprint representation of the molecules.

    Returns:
        tuple: Mean and standard deviation of internal diversity.
    """
    diversity = [1 - x for x in average_agg_tanimoto(gen, gen, agg="mean", intdiv=True)]
    return np.mean(diversity), np.std(diversity)


def average_agg_tanimoto(stock_vecs, gen_vecs, batch_size=5000, agg='max', device='cpu', p=1, intdiv=False):
    """
    Computes the average aggregated Tanimoto similarity between two sets of molecular fingerprints.

    For each fingerprint in gen_vecs, finds the maximum (or mean) similarity with fingerprints in stock_vecs.

    Args:
        stock_vecs (numpy.ndarray): Array of fingerprint vectors from the reference set.
        gen_vecs (numpy.ndarray): Array of fingerprint vectors from the generated set.
        batch_size (int): Batch size for processing fingerprints.
        agg (str): Aggregation method, either 'max' or 'mean'.
        device (str): Device to perform computations on.
        p (int): Power for the generalized mean.
        intdiv (bool): Whether to return individual similarities instead of their average.

    Returns:
        float or numpy.ndarray: Average aggregated Tanimoto similarity or array of individual scores.
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    agg_tanimoto = np.zeros(len(gen_vecs))
    total = np.zeros(len(gen_vecs))
    for j in range(0, stock_vecs.shape[0], batch_size):
        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
        for i in range(0, gen_vecs.shape[0], batch_size):
            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
            y_gen = y_gen.transpose(0, 1)
            tp = torch.mm(x_stock, y_gen)
            # Tanimoto/Jaccard similarity: intersection over union of "on" bits
            jac = (tp / (x_stock.sum(1, keepdim=True) + y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
            jac[np.isnan(jac)] = 1
            if p != 1:
                jac = jac ** p
            if agg == 'max':
                agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
                    agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
            elif agg == 'mean':
                agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
                total[i:i + y_gen.shape[1]] += jac.shape[0]
    if agg == 'mean':
        agg_tanimoto /= total
    if p != 1:
        agg_tanimoto = agg_tanimoto ** (1 / p)
    if intdiv:
        return agg_tanimoto
    else:
        return np.mean(agg_tanimoto)
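
# Usage sketch with tiny hypothetical binary fingerprints of shape [n, n_bits]:
#   >>> stock = np.array([[1, 1, 0, 0], [0, 0, 1, 1]])
#   >>> gen = np.array([[1, 1, 1, 0]])
#   >>> average_agg_tanimoto(stock, gen)  # max over stock: max(2/3, 1/4) ~ 0.667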


def str2bool(v):
    """
    Converts a string to a boolean.

    Args:
        v (str): Input string.

    Returns:
        bool: True if the string is 'true' (case insensitive), else False.
    """
    # Compare against a tuple: `in ('true')` would do a substring check on the string 'true'
    return v.lower() in ('true',)


def obey_lipinski(mol):
    """
    Checks if a molecule obeys Lipinski's Rule of Five.

    The function evaluates weight, hydrogen bond donors and acceptors, logP, and rotatable bonds.

    Args:
        mol (RDKit Mol): Molecule object.

    Returns:
        int: Number of Lipinski rules satisfied.
    """
    mol = deepcopy(mol)
    Chem.SanitizeMol(mol)
    rule_1 = Descriptors.ExactMolWt(mol) < 500
    rule_2 = Lipinski.NumHDonors(mol) <= 5
    rule_3 = Lipinski.NumHAcceptors(mol) <= 10
    logp = Crippen.MolLogP(mol)
    rule_4 = -2 <= logp <= 5
    rule_5 = rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
    return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
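
# Illustrative example: aspirin is expected to satisfy all five rules.
#   >>> obey_lipinski(Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O"))  # -> 5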


def obey_veber(mol):
    """
    Checks if a molecule obeys Veber's rules.

    Veber's rules focus on the number of rotatable bonds and topological polar surface area.

    Args:
        mol (RDKit Mol): Molecule object.

    Returns:
        int: Number of Veber's rules satisfied.
    """
    mol = deepcopy(mol)
    Chem.SanitizeMol(mol)
    rule_1 = rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
    rule_2 = rdMolDescriptors.CalcTPSA(mol) <= 140
    return np.sum([int(a) for a in [rule_1, rule_2]])


def load_pains_filters():
    """
    Loads the PAINS (Pan-Assay INterference compoundS) filters A, B, and C.

    Returns:
        FilterCatalog: An RDKit FilterCatalog object containing PAINS filters.
    """
    params = FilterCatalog.FilterCatalogParams()
    params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_A)
    params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_B)
    params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_C)
    catalog = FilterCatalog.FilterCatalog(params)
    return catalog


def is_pains(mol, catalog):
    """
    Checks if the given molecule is a PAINS compound.

    Args:
        mol (RDKit Mol): Molecule object.
        catalog (FilterCatalog): A catalog of PAINS filters.

    Returns:
        bool: True if the molecule matches a PAINS filter, else False.
    """
    entry = catalog.GetFirstMatch(mol)
    return entry is not None
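
# Usage sketch: build the catalog once, then screen molecules against it.
#   >>> catalog = load_pains_filters()
#   >>> is_pains(Chem.MolFromSmiles("CCO"), catalog)
#   False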
726
727
def fragmenter(mol):
728
    """
729
    Fragments a molecule using BRICS and returns a list of fragment SMILES.
730
    
731
    Args:
732
        mol (str or RDKit Mol): Input molecule.
733
    
734
    Returns:
735
        list: List of fragment SMILES strings.
736
    """
737
    fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol))
738
    fgs_smi = Chem.MolToSmiles(fgs).split(".")
739
    return fgs_smi


def compute_fragments(mol_list, n_jobs=1):
    """
    Fragments a list of molecules using BRICS and returns a counter of fragment occurrences.

    Args:
        mol_list (list): List of molecules (SMILES or RDKit Mol).
        n_jobs (int): Number of parallel jobs.

    Returns:
        Counter: A Counter dictionary mapping fragment SMILES to counts.
    """
    fragments = Counter()
    for mol_frag in mapper(n_jobs)(fragmenter, mol_list):
        fragments.update(mol_frag)
    return fragments


def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
    """
    Extracts scaffolds from a list of molecules as canonical SMILES.

    Only scaffolds with at least min_rings rings are considered.

    Args:
        mol_list (list): List of molecules.
        n_jobs (int): Number of parallel jobs.
        min_rings (int): Minimum number of rings required in a scaffold.

    Returns:
        Counter: A Counter mapping scaffold SMILES to counts.
    """
    map_ = mapper(n_jobs)
    scaffolds = Counter(map_(partial(compute_scaffold, min_rings=min_rings), mol_list))
    if None in scaffolds:
        scaffolds.pop(None)
    return scaffolds


def get_n_rings(mol):
    """
    Computes the number of rings in a molecule.

    Args:
        mol (RDKit Mol): Molecule object.

    Returns:
        int: Number of rings.
    """
    return mol.GetRingInfo().NumRings()


def compute_scaffold(mol, min_rings=2):
    """
    Computes the Murcko scaffold of a molecule and returns its canonical SMILES if it has enough rings.

    Args:
        mol (str or RDKit Mol): Input molecule.
        min_rings (int): Minimum number of rings required.

    Returns:
        str or None: Canonical SMILES of the scaffold if valid, else None.
    """
    mol = get_mol(mol)
    try:
        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
    except (ValueError, RuntimeError):
        return None
    n_rings = get_n_rings(scaffold)
    scaffold_smiles = Chem.MolToSmiles(scaffold)
    if scaffold_smiles == '' or n_rings < min_rings:
        return None
    return scaffold_smiles
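
# Illustrative example: 4-methylbiphenyl reduces to the two-ring biphenyl
# scaffold, which passes the default min_rings=2; toluene's single ring does not.
#   >>> compute_scaffold("Cc1ccc(-c2ccccc2)cc1")  # -> 'c1ccc(-c2ccccc2)cc1'
#   >>> compute_scaffold("Cc1ccccc1") is None
#   True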


class Metric:
    """
    Abstract base class for chemical metrics.

    Derived classes should implement the precalc and metric methods.
    """
    def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs):
        self.n_jobs = n_jobs
        self.device = device
        self.batch_size = batch_size
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __call__(self, ref=None, gen=None, pref=None, pgen=None):
        """
        Computes the metric between reference and generated molecules.

        Exactly one of ref or pref, and exactly one of gen or pgen, must be provided.

        Args:
            ref: Reference molecule list.
            gen: Generated molecule list.
            pref: Precalculated reference representation.
            pgen: Precalculated generated representation.

        Returns:
            Metric value computed by the metric method.
        """
        assert (ref is None) != (pref is None), "specify ref xor pref"
        assert (gen is None) != (pgen is None), "specify gen xor pgen"
        if pref is None:
            pref = self.precalc(ref)
        if pgen is None:
            pgen = self.precalc(gen)
        return self.metric(pref, pgen)

    def precalc(self, molecules):
        """
        Pre-calculates necessary representations from a list of molecules.
        Should be implemented by derived classes.
        """
        raise NotImplementedError

    def metric(self, pref, pgen):
        """
        Computes the metric given precalculated representations.
        Should be implemented by derived classes.
        """
        raise NotImplementedError


class FragMetric(Metric):
    """
    Metrics based on molecular fragments.
    """
    def precalc(self, mols):
        return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)}

    def metric(self, pref, pgen):
        return cos_similarity(pref['frag'], pgen['frag'])


class ScafMetric(Metric):
    """
    Metrics based on molecular scaffolds.
    """
    def precalc(self, mols):
        return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)}

    def metric(self, pref, pgen):
        return cos_similarity(pref['scaf'], pgen['scaf'])
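
# Usage sketch (hypothetical molecule lists ref_mols and gen_mols):
#   >>> frag_metric = FragMetric(n_jobs=1)
#   >>> frag_metric(ref=ref_mols, gen=gen_mols)  # cosine similarity of BRICS fragment counts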


def cos_similarity(ref_counts, gen_counts):
    """
    Computes cosine similarity between two molecular count vectors.

    Args:
        ref_counts (dict): Reference counts (e.g., a fragment or scaffold Counter).
        gen_counts (dict): Generated counts (e.g., a fragment or scaffold Counter).

    Returns:
        float: Cosine similarity between the two count vectors.
    """
    if len(ref_counts) == 0 or len(gen_counts) == 0:
        return np.nan
    keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys()))
    ref_vec = np.array([ref_counts.get(k, 0) for k in keys])
    gen_vec = np.array([gen_counts.get(k, 0) for k in keys])
    return 1 - cos_distance(ref_vec, gen_vec)
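
# Illustrative example: parallel count vectors give similarity 1.0.
#   >>> cos_similarity(Counter({"c1ccccc1": 2}), Counter({"c1ccccc1": 5}))
#   1.0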