Diff of /src/util/utils.py [000000] .. [7d53f6]

Switch to side-by-side view

--- a
+++ b/src/util/utils.py
@@ -0,0 +1,930 @@
+import os
+import time
+import math
+import datetime
+import warnings
+import itertools
+from copy import deepcopy
+from functools import partial
+from collections import Counter
+from multiprocessing import Pool
+from statistics import mean
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+from scipy.spatial.distance import cosine as cos_distance
+
+import torch
+import wandb
+
+from rdkit import Chem, DataStructs, RDLogger
+from rdkit.Chem import (
+    AllChem,
+    Draw,
+    Descriptors,
+    Lipinski,
+    Crippen,
+    rdMolDescriptors,
+    FilterCatalog,
+)
+from rdkit.Chem.Scaffolds import MurckoScaffold
+
+# Disable RDKit warnings
+RDLogger.DisableLog("rdApp.*")
+
+
+class Metrics(object):
+    """
+    Collection of static methods to compute various metrics for molecules.
+    """
+
+    @staticmethod
+    def valid(x):
+        """
+        Checks whether the molecule is valid.
+        
+        Args:
+            x: RDKit molecule object.
+        
+        Returns:
+            bool: True if molecule is valid and has a non-empty SMILES representation.
+        """
+        return x is not None and Chem.MolToSmiles(x) != ''
+
+    @staticmethod
+    def tanimoto_sim_1v2(data1, data2):
+        """
+        Computes the average Tanimoto similarity for paired fingerprints.
+        
+        Args:
+            data1: Fingerprint data for first set.
+            data2: Fingerprint data for second set.
+        
+        Returns:
+            float: The average Tanimoto similarity between corresponding fingerprints.
+        """
+        # Determine the minimum size between two arrays for pairing
+        min_len = data1.size if data1.size > data2.size else data2
+        sims = []
+        for i in range(min_len):
+            sim = DataStructs.FingerprintSimilarity(data1[i], data2[i])
+            sims.append(sim)
+        # Use 'mean' from statistics; note that variable 'sim' was used, corrected to use sims list.
+        mean_sim = mean(sims)
+        return mean_sim
+
+    @staticmethod
+    def mol_length(x):
+        """
+        Computes the length of the largest fragment (by character count) in a SMILES string.
+        
+        Args:
+            x (str): SMILES string.
+        
+        Returns:
+            int: Number of alphabetic characters in the longest fragment of the SMILES.
+        """
+        if x is not None:
+            # Split at dots (.) and take the fragment with maximum length, then count alphabetic characters.
+            return len([char for char in max(x.split(sep="."), key=len).upper() if char.isalpha()])
+        else:
+            return 0
+
+    @staticmethod
+    def max_component(data, max_len):
+        """
+        Returns the average normalized length of molecules in the dataset.
+        
+        Each molecule's length is computed and divided by max_len, then averaged.
+        
+        Args:
+            data (iterable): Collection of SMILES strings.
+            max_len (int): Maximum possible length for normalization.
+        
+        Returns:
+            float: Normalized average length.
+        """
+        lengths = np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)
+        return (lengths / max_len).mean()
+
+    @staticmethod
+    def mean_atom_type(data):
+        """
+        Computes the average number of unique atom types in the provided node data.
+        
+        Args:
+            data (iterable): Iterable containing node data with unique atom types.
+        
+        Returns:
+            float: The average count of unique atom types, subtracting one.
+        """
+        atom_types_used = []
+        for i in data:
+            # Assuming each element i has a .unique() method that returns unique atom types.
+            atom_types_used.append(len(i.unique().tolist()))
+        av_type = np.mean(atom_types_used) - 1
+        return av_type
+
+
+def mols2grid_image(mols, path):
+    """
+    Saves grid images for a list of molecules.
+    
+    For each molecule in the list, computes 2D coordinates and saves an image file.
+    
+    Args:
+        mols (list): List of RDKit molecule objects.
+        path (str): Directory where images will be saved.
+    """
+    # Replace None molecules with an empty molecule
+    mols = [e if e is not None else Chem.RWMol() for e in mols]
+
+    for i in range(len(mols)):
+        if Metrics.valid(mols[i]):
+            AllChem.Compute2DCoords(mols[i])
+            file_path = os.path.join(path, "{}.png".format(i + 1))
+            Draw.MolToFile(mols[i], file_path, size=(1200, 1200))
+            # wandb.save(file_path)  # Optionally save to Weights & Biases
+        else:
+            continue
+
+
+def save_smiles_matrices(mols, edges_hard, nodes_hard, path, data_source=None):
+    """
+    Saves the edge and node matrices along with SMILES strings to text files.
+    
+    Each file contains the edge matrix, node matrix, and SMILES representation for a molecule.
+    
+    Args:
+        mols (list): List of RDKit molecule objects.
+        edges_hard (torch.Tensor): Tensor of edge features.
+        nodes_hard (torch.Tensor): Tensor of node features.
+        path (str): Directory where files will be saved.
+        data_source: Optional data source information (not used in function).
+    """
+    mols = [e if e is not None else Chem.RWMol() for e in mols]
+
+    for i in range(len(mols)):
+        if Metrics.valid(mols[i]):
+            save_path = os.path.join(path, "{}.txt".format(i + 1))
+            with open(save_path, "a") as f:
+                np.savetxt(f, edges_hard[i].cpu().numpy(), header="edge matrix:\n", fmt='%1.2f')
+                f.write("\n")
+                np.savetxt(f, nodes_hard[i].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:", fmt='%1.2f')
+                f.write("\n")
+            # Append the SMILES representation to the file
+            with open(save_path, "a") as f:
+                print(Chem.MolToSmiles(mols[i]), file=f)
+            # wandb.save(save_path)  # Optionally save to Weights & Biases
+        else:
+            continue
+
+def dense_to_sparse_with_attr(adj):
+    """
+    Converts a dense adjacency matrix to a sparse representation.
+    
+    Args:
+        adj (torch.Tensor): Adjacency matrix tensor (2D or 3D) with square last two dimensions.
+    
+    Returns:
+        tuple: A tuple containing indices and corresponding edge attributes.
+    """
+    assert adj.dim() >= 2 and adj.dim() <= 3
+    assert adj.size(-1) == adj.size(-2)
+
+    index = adj.nonzero(as_tuple=True)
+    edge_attr = adj[index]
+
+    if len(index) == 3:
+        batch = index[0] * adj.size(-1)
+        index = (batch + index[1], batch + index[2])
+    return index, edge_attr
+
+
+def mol_sample(sample_directory, edges, nodes, idx, i, matrices2mol, dataset_name):
+    """
+    Samples molecules from edge and node predictions, then saves grid images and text files.
+    
+    Args:
+        sample_directory (str): Directory to save the samples.
+        edges (torch.Tensor): Edge predictions tensor.
+        nodes (torch.Tensor): Node predictions tensor.
+        idx (int): Current index for naming the sample.
+        i (int): Epoch/iteration index.
+        matrices2mol (callable): Function to convert matrices to RDKit molecule.
+        dataset_name (str): Name of the dataset for file naming.
+    """
+    sample_path = os.path.join(sample_directory, "{}_{}-epoch_iteration".format(idx + 1, i + 1))
+    # Get the index of the maximum predicted feature along the last dimension
+    g_edges_hat_sample = torch.max(edges, -1)[1]
+    g_nodes_hat_sample = torch.max(nodes, -1)[1]
+    # Convert matrices to molecule objects
+    mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
+                        strict=True, file_name=dataset_name)
+           for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
+
+    if not os.path.exists(sample_path):
+        os.makedirs(sample_path)
+
+    mols2grid_image(mol, sample_path)
+    save_smiles_matrices(mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), sample_path)
+
+    # Remove the directory if no files were saved
+    if len(os.listdir(sample_path)) == 0:
+        os.rmdir(sample_path)
+
+    print("Valid molecules are saved.")
+    print("Valid matrices and smiles are saved")
+
+
+def logging(log_path, start_time, i, idx, loss, save_path, drug_smiles, edge, node, 
+            matrices2mol, dataset_name, real_adj, real_annot, drug_vecs):
+    """
+    Logs training statistics and evaluation metrics.
+    
+    The function generates molecules from predictions, computes various metrics such as
+    validity, uniqueness, novelty, and similarity scores, and logs them using wandb and a file.
+    
+    Args:
+        log_path (str): Path to save the log file.
+        start_time (float): Start time to compute elapsed time.
+        i (int): Current iteration index.
+        idx (int): Current epoch index.
+        loss (dict): Dictionary to update with loss and metric values.
+        save_path (str): Directory path to save sample outputs.
+        drug_smiles (list): List of reference drug SMILES.
+        edge (torch.Tensor): Edge prediction tensor.
+        node (torch.Tensor): Node prediction tensor.
+        matrices2mol (callable): Function to convert matrices to molecules.
+        dataset_name (str): Dataset name.
+        real_adj (torch.Tensor): Ground truth adjacency matrix tensor.
+        real_annot (torch.Tensor): Ground truth annotation tensor.
+        drug_vecs (list): List of drug vectors for similarity calculation.
+    """
+    g_edges_hat_sample = torch.max(edge, -1)[1]
+    g_nodes_hat_sample = torch.max(node, -1)[1]
+
+    a_tensor_sample = torch.max(real_adj, -1)[1].float()
+    x_tensor_sample = torch.max(real_annot, -1)[1].float()
+
+    # Generate molecules from predictions and real data
+    mols = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
+                         strict=True, file_name=dataset_name)
+            for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
+    real_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
+                              strict=True, file_name=dataset_name)
+                for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]
+
+    # Compute average number of atom types
+    atom_types_average = Metrics.mean_atom_type(g_nodes_hat_sample)
+    real_smiles = [Chem.MolToSmiles(x) for x in real_mol if x is not None]
+    gen_smiles = []
+    uniq_smiles = []
+    for line in mols:
+        if line is not None:
+            gen_smiles.append(Chem.MolToSmiles(line))
+            uniq_smiles.append(Chem.MolToSmiles(line))
+        elif line is None:
+            gen_smiles.append(None)
+
+    # Process SMILES to take the longest fragment if multiple are present
+    gen_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in gen_smiles]
+    uniq_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in uniq_smiles]
+
+    # Save the generated SMILES to a text file
+    sample_save_dir = os.path.join(save_path, "samples.txt")
+    with open(sample_save_dir, "a") as f:
+        for s in gen_smiles_saves:
+            if s is not None:
+                f.write(s + "\n")
+
+    k = len(set(uniq_smiles_saves) - {None})
+    et = time.time() - start_time
+    et = str(datetime.timedelta(seconds=et))[:-7]
+    log_str = "Elapsed [{}], Epoch/Iteration [{}/{}]".format(et, idx, i + 1)
+    
+    # Generate molecular fingerprints for similarity computations
+    gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in mols if x is not None]
+    chembl_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_mol if x is not None]
+
+    # Compute evaluation metrics: validity, uniqueness, novelty, similarity scores, and average maximum molecule length.
+    valid = fraction_valid(gen_smiles_saves)
+    unique = fraction_unique(uniq_smiles_saves, k)
+    novel_starting_mol = novelty(gen_smiles_saves, real_smiles)
+    novel_akt = novelty(gen_smiles_saves, drug_smiles)
+    if len(uniq_smiles_saves) == 0:
+        snn_chembl = 0
+        snn_akt = 0
+        maxlen = 0
+    else:
+        snn_chembl = average_agg_tanimoto(np.array(chembl_vecs), np.array(gen_vecs))
+        snn_akt = average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs))
+        maxlen = Metrics.max_component(uniq_smiles_saves, 45)
+
+    # Update loss dictionary with computed metrics
+    loss.update({
+        'Validity': valid,
+        'Uniqueness': unique,
+        'Novelty': novel_starting_mol,
+        'Novelty_akt': novel_akt,
+        'SNN_chembl': snn_chembl,
+        'SNN_akt': snn_akt,
+        'MaxLen': maxlen,
+        'Atom_types': atom_types_average
+    })
+
+    # Log metrics using wandb
+    wandb.log({
+        "Validity": valid,
+        "Uniqueness": unique,
+        "Novelty": novel_starting_mol,
+        "Novelty_akt": novel_akt,
+        "SNN_chembl": snn_chembl,
+        "SNN_akt": snn_akt,
+        "MaxLen": maxlen,
+        "Atom_types": atom_types_average
+    })
+
+    # Append each metric to the log string and write to the log file
+    for tag, value in loss.items():
+        log_str += ", {}: {:.4f}".format(tag, value)
+    with open(log_path, "a") as f:
+        f.write(log_str + "\n")
+    print(log_str)
+    print("\n")
+
+
+def plot_grad_flow(named_parameters, model, itera, epoch, grad_flow_directory):
+    """
+    Plots the gradients flowing through different layers during training.
+    
+    This is useful to check for possible gradient vanishing or exploding problems.
+    
+    Args:
+        named_parameters (iterable): Iterable of (name, parameter) tuples from the model.
+        model (str): Name of the model (used for saving the plot).
+        itera (int): Iteration index.
+        epoch (int): Current epoch.
+        grad_flow_directory (str): Directory to save the gradient flow plot.
+    """
+    ave_grads = []
+    max_grads = []
+    layers = []
+    for n, p in named_parameters:
+        if p.requires_grad and ("bias" not in n):
+            layers.append(n)
+            ave_grads.append(p.grad.abs().mean().cpu())
+            max_grads.append(p.grad.abs().max().cpu())
+    # Plot maximum gradients and average gradients for each layer
+    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
+    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
+    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
+    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
+    plt.xlim(left=0, right=len(ave_grads))
+    plt.ylim(bottom=-0.001, top=1)  # Zoom in on lower gradient regions
+    plt.xlabel("Layers")
+    plt.ylabel("Average Gradient")
+    plt.title("Gradient Flow")
+    plt.grid(True)
+    plt.legend([
+        Line2D([0], [0], color="c", lw=4),
+        Line2D([0], [0], color="b", lw=4),
+        Line2D([0], [0], color="k", lw=4)
+    ], ['max-gradient', 'mean-gradient', 'zero-gradient'])
+    # Save the plot to the specified directory
+    plt.savefig(os.path.join(grad_flow_directory, "weights_" + model + "_" + str(itera) + "_" + str(epoch) + ".png"), dpi=500, bbox_inches='tight')
+
+
+def get_mol(smiles_or_mol):
+    """
+    Loads a SMILES string or molecule into an RDKit molecule object.
+    
+    Args:
+        smiles_or_mol (str or RDKit Mol): SMILES string or RDKit molecule.
+    
+    Returns:
+        RDKit Mol or None: Sanitized molecule object, or None if invalid.
+    """
+    if isinstance(smiles_or_mol, str):
+        if len(smiles_or_mol) == 0:
+            return None
+        mol = Chem.MolFromSmiles(smiles_or_mol)
+        if mol is None:
+            return None
+        try:
+            Chem.SanitizeMol(mol)
+        except ValueError:
+            return None
+        return mol
+    return smiles_or_mol
+
+
+def mapper(n_jobs):
+    """
+    Returns a mapping function for parallel or serial processing.
+    
+    If n_jobs == 1, returns the built-in map function.
+    If n_jobs > 1, returns a function that uses a multiprocessing pool.
+    
+    Args:
+        n_jobs (int or pool object): Number of jobs or a Pool instance.
+    
+    Returns:
+        callable: A function that acts like map.
+    """
+    if n_jobs == 1:
+        def _mapper(*args, **kwargs):
+            return list(map(*args, **kwargs))
+        return _mapper
+    if isinstance(n_jobs, int):
+        pool = Pool(n_jobs)
+        def _mapper(*args, **kwargs):
+            try:
+                result = pool.map(*args, **kwargs)
+            finally:
+                pool.terminate()
+            return result
+        return _mapper
+    return n_jobs.map
+
+
+def remove_invalid(gen, canonize=True, n_jobs=1):
+    """
+    Removes invalid molecules from the provided dataset.
+    
+    Optionally canonizes the SMILES strings.
+    
+    Args:
+        gen (list): List of SMILES strings.
+        canonize (bool): Whether to convert to canonical SMILES.
+        n_jobs (int): Number of parallel jobs.
+    
+    Returns:
+        list: Filtered list of valid molecules.
+    """
+    if not canonize:
+        mols = mapper(n_jobs)(get_mol, gen)
+        return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
+    return [x for x in mapper(n_jobs)(canonic_smiles, gen) if x is not None]
+
+
+def fraction_valid(gen, n_jobs=1):
+    """
+    Computes the fraction of valid molecules in the dataset.
+    
+    Args:
+        gen (list): List of SMILES strings.
+        n_jobs (int): Number of parallel jobs.
+    
+    Returns:
+        float: Fraction of molecules that are valid.
+    """
+    gen = mapper(n_jobs)(get_mol, gen)
+    return 1 - gen.count(None) / len(gen)
+
+
+def canonic_smiles(smiles_or_mol):
+    """
+    Converts a SMILES string or molecule to its canonical SMILES.
+    
+    Args:
+        smiles_or_mol (str or RDKit Mol): Input molecule.
+    
+    Returns:
+        str or None: Canonical SMILES string or None if invalid.
+    """
+    mol = get_mol(smiles_or_mol)
+    if mol is None:
+        return None
+    return Chem.MolToSmiles(mol)
+
+
+def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
+    """
+    Computes the fraction of unique molecules.
+    
+    Optionally computes unique@k, where only the first k molecules are considered.
+    
+    Args:
+        gen (list): List of SMILES strings.
+        k (int): Optional cutoff for unique@k computation.
+        n_jobs (int): Number of parallel jobs.
+        check_validity (bool): Whether to check for validity of molecules.
+    
+    Returns:
+        float: Fraction of unique molecules.
+    """
+    if k is not None:
+        if len(gen) < k:
+            warnings.warn("Can't compute unique@{}.".format(k) +
+                          " gen contains only {} molecules".format(len(gen)))
+        gen = gen[:k]
+    if check_validity:
+        canonic = list(mapper(n_jobs)(canonic_smiles, gen))
+        canonic = [i for i in canonic if i is not None]
+    set_cannonic = set(canonic)
+    return 0 if len(canonic) == 0 else len(set_cannonic) / len(canonic)
+
+
+def novelty(gen, train, n_jobs=1):
+    """
+    Computes the novelty score of generated molecules.
+    
+    Novelty is defined as the fraction of generated molecules that do not appear in the training set.
+    
+    Args:
+        gen (list): List of generated SMILES strings.
+        train (list): List of training SMILES strings.
+        n_jobs (int): Number of parallel jobs.
+    
+    Returns:
+        float: Novelty score.
+    """
+    gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
+    gen_smiles_set = set(gen_smiles) - {None}
+    train_set = set(train)
+    return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
+
+
+def internal_diversity(gen):
+    """
+    Computes the internal diversity of a set of molecules.
+    
+    Internal diversity is defined as one minus the average Tanimoto similarity between all pairs.
+    
+    Args:
+        gen: Array-like representation of molecules.
+    
+    Returns:
+        tuple: Mean and standard deviation of internal diversity.
+    """
+    diversity = [1 - x for x in average_agg_tanimoto(gen, gen, agg="mean", intdiv=True)]
+    return np.mean(diversity), np.std(diversity)
+
+
+def average_agg_tanimoto(stock_vecs, gen_vecs, batch_size=5000, agg='max', device='cpu', p=1, intdiv=False):
+    """
+    Computes the average aggregated Tanimoto similarity between two sets of molecular fingerprints.
+    
+    For each fingerprint in gen_vecs, finds the closest (max or mean) similarity with fingerprints in stock_vecs.
+    
+    Args:
+        stock_vecs (numpy.ndarray): Array of fingerprint vectors from the reference set.
+        gen_vecs (numpy.ndarray): Array of fingerprint vectors from the generated set.
+        batch_size (int): Batch size for processing fingerprints.
+        agg (str): Aggregation method, either 'max' or 'mean'.
+        device (str): Device to perform computations on.
+        p (int): Power for averaging.
+        intdiv (bool): Whether to return individual similarities or the average.
+    
+    Returns:
+        float or numpy.ndarray: Average aggregated Tanimoto similarity or array of individual scores.
+    """
+    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
+    agg_tanimoto = np.zeros(len(gen_vecs))
+    total = np.zeros(len(gen_vecs))
+    for j in range(0, stock_vecs.shape[0], batch_size):
+        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
+        for i in range(0, gen_vecs.shape[0], batch_size):
+            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
+            y_gen = y_gen.transpose(0, 1)
+            tp = torch.mm(x_stock, y_gen)
+            # Compute Jaccard/Tanimoto similarity
+            jac = (tp / (x_stock.sum(1, keepdim=True) + y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
+            jac[np.isnan(jac)] = 1
+            if p != 1:
+                jac = jac ** p
+            if agg == 'max':
+                agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
+                    agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
+            elif agg == 'mean':
+                agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
+                total[i:i + y_gen.shape[1]] += jac.shape[0]
+    if agg == 'mean':
+        agg_tanimoto /= total
+    if p != 1:
+        agg_tanimoto = (agg_tanimoto) ** (1 / p)
+    if intdiv:
+        return agg_tanimoto
+    else:
+        return np.mean(agg_tanimoto)
+
+
+def str2bool(v):
+    """
+    Converts a string to a boolean.
+    
+    Args:
+        v (str): Input string.
+    
+    Returns:
+        bool: True if the string is 'true' (case insensitive), else False.
+    """
+    return v.lower() in ('true')
+
+
+def obey_lipinski(mol):
+    """
+    Checks if a molecule obeys Lipinski's Rule of Five.
+    
+    The function evaluates weight, hydrogen bond donors and acceptors, logP, and rotatable bonds.
+    
+    Args:
+        mol (RDKit Mol): Molecule object.
+    
+    Returns:
+        int: Number of Lipinski rules satisfied.
+    """
+    mol = deepcopy(mol)
+    Chem.SanitizeMol(mol)
+    rule_1 = Descriptors.ExactMolWt(mol) < 500
+    rule_2 = Lipinski.NumHDonors(mol) <= 5
+    rule_3 = Lipinski.NumHAcceptors(mol) <= 10
+    rule_4 = (logp := Crippen.MolLogP(mol) >= -2) & (logp <= 5)
+    rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
+    return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
+
+
+def obey_veber(mol):
+    """
+    Checks if a molecule obeys Veber's rules.
+    
+    Veber's rules focus on the number of rotatable bonds and topological polar surface area.
+    
+    Args:
+        mol (RDKit Mol): Molecule object.
+    
+    Returns:
+        int: Number of Veber's rules satisfied.
+    """
+    mol = deepcopy(mol)
+    Chem.SanitizeMol(mol)
+    rule_1 = rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
+    rule_2 = rdMolDescriptors.CalcTPSA(mol) <= 140
+    return np.sum([int(a) for a in [rule_1, rule_2]])
+
+
+def load_pains_filters():
+    """
+    Loads the PAINS (Pan-Assay INterference compoundS) filters A, B, and C.
+    
+    Returns:
+        FilterCatalog: An RDKit FilterCatalog object containing PAINS filters.
+    """
+    params = FilterCatalog.FilterCatalogParams()
+    params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_A)
+    params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_B)
+    params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_C)
+    catalog = FilterCatalog.FilterCatalog(params)
+    return catalog
+
+
+def is_pains(mol, catalog):
+    """
+    Checks if the given molecule is a PAINS compound.
+    
+    Args:
+        mol (RDKit Mol): Molecule object.
+        catalog (FilterCatalog): A catalog of PAINS filters.
+    
+    Returns:
+        bool: True if the molecule matches a PAINS filter, else False.
+    """
+    entry = catalog.GetFirstMatch(mol)
+    return entry is not None
+
+
+def mapper(n_jobs):
+    """
+    Returns a mapping function for parallel or serial processing.
+    
+    If n_jobs == 1, returns the built-in map function.
+    If n_jobs > 1, returns a function that uses a multiprocessing pool.
+    
+    Args:
+        n_jobs (int or pool object): Number of jobs or a Pool instance.
+    
+    Returns:
+        callable: A function that acts like map.
+    """
+    if n_jobs == 1:
+        def _mapper(*args, **kwargs):
+            return list(map(*args, **kwargs))
+        return _mapper
+    if isinstance(n_jobs, int):
+        pool = Pool(n_jobs)
+        def _mapper(*args, **kwargs):
+            try:
+                result = pool.map(*args, **kwargs)
+            finally:
+                pool.terminate()
+            return result
+        return _mapper
+    return n_jobs.map
+
+
+def fragmenter(mol):
+    """
+    Fragments a molecule using BRICS and returns a list of fragment SMILES.
+    
+    Args:
+        mol (str or RDKit Mol): Input molecule.
+    
+    Returns:
+        list: List of fragment SMILES strings.
+    """
+    fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol))
+    fgs_smi = Chem.MolToSmiles(fgs).split(".")
+    return fgs_smi
+
+
+def get_mol(smiles_or_mol):
+    """
+    Loads a SMILES string or molecule into an RDKit molecule object.
+    
+    Args:
+        smiles_or_mol (str or RDKit Mol): SMILES string or molecule.
+    
+    Returns:
+        RDKit Mol or None: Sanitized molecule object or None if invalid.
+    """
+    if isinstance(smiles_or_mol, str):
+        if len(smiles_or_mol) == 0:
+            return None
+        mol = Chem.MolFromSmiles(smiles_or_mol)
+        if mol is None:
+            return None
+        try:
+            Chem.SanitizeMol(mol)
+        except ValueError:
+            return None
+        return mol
+    return smiles_or_mol
+
+
+def compute_fragments(mol_list, n_jobs=1):
+    """
+    Fragments a list of molecules using BRICS and returns a counter of fragment occurrences.
+    
+    Args:
+        mol_list (list): List of molecules (SMILES or RDKit Mol).
+        n_jobs (int): Number of parallel jobs.
+    
+    Returns:
+        Counter: A Counter dictionary mapping fragment SMILES to counts.
+    """
+    fragments = Counter()
+    for mol_frag in mapper(n_jobs)(fragmenter, mol_list):
+        fragments.update(mol_frag)
+    return fragments
+
+
+def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
+    """
+    Extracts scaffolds from a list of molecules as canonical SMILES.
+    
+    Only scaffolds with at least min_rings rings are considered.
+    
+    Args:
+        mol_list (list): List of molecules.
+        n_jobs (int): Number of parallel jobs.
+        min_rings (int): Minimum number of rings required in a scaffold.
+    
+    Returns:
+        Counter: A Counter mapping scaffold SMILES to counts.
+    """
+    scaffolds = Counter()
+    map_ = mapper(n_jobs)
+    scaffolds = Counter(map_(partial(compute_scaffold, min_rings=min_rings), mol_list))
+    if None in scaffolds:
+        scaffolds.pop(None)
+    return scaffolds
+
+
+def get_n_rings(mol):
+    """
+    Computes the number of rings in a molecule.
+    
+    Args:
+        mol (RDKit Mol): Molecule object.
+    
+    Returns:
+        int: Number of rings.
+    """
+    return mol.GetRingInfo().NumRings()
+
+
+def compute_scaffold(mol, min_rings=2):
+    """
+    Computes the Murcko scaffold of a molecule and returns its canonical SMILES if it has enough rings.
+    
+    Args:
+        mol (str or RDKit Mol): Input molecule.
+        min_rings (int): Minimum number of rings required.
+    
+    Returns:
+        str or None: Canonical SMILES of the scaffold if valid, else None.
+    """
+    mol = get_mol(mol)
+    try:
+        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
+    except (ValueError, RuntimeError):
+        return None
+    n_rings = get_n_rings(scaffold)
+    scaffold_smiles = Chem.MolToSmiles(scaffold)
+    if scaffold_smiles == '' or n_rings < min_rings:
+        return None
+    return scaffold_smiles
+
+
+class Metric:
+    """
+    Abstract base class for chemical metrics.
+    
+    Derived classes should implement the precalc and metric methods.
+    """
+    def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs):
+        self.n_jobs = n_jobs
+        self.device = device
+        self.batch_size = batch_size
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+    def __call__(self, ref=None, gen=None, pref=None, pgen=None):
+        """
+        Computes the metric between reference and generated molecules.
+        
+        Exactly one of ref or pref, and gen or pgen should be provided.
+        
+        Args:
+            ref: Reference molecule list.
+            gen: Generated molecule list.
+            pref: Precalculated reference metric.
+            pgen: Precalculated generated metric.
+        
+        Returns:
+            Metric value computed by the metric method.
+        """
+        assert (ref is None) != (pref is None), "specify ref xor pref"
+        assert (gen is None) != (pgen is None), "specify gen xor pgen"
+        if pref is None:
+            pref = self.precalc(ref)
+        if pgen is None:
+            pgen = self.precalc(gen)
+        return self.metric(pref, pgen)
+
+    def precalc(self, molecules):
+        """
+        Pre-calculates necessary representations from a list of molecules.
+        Should be implemented by derived classes.
+        """
+        raise NotImplementedError
+
+    def metric(self, pref, pgen):
+        """
+        Computes the metric given precalculated representations.
+        Should be implemented by derived classes.
+        """
+        raise NotImplementedError
+
+
+class FragMetric(Metric):
+    """
+    Metrics based on molecular fragments.
+    """
+    def precalc(self, mols):
+        return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)}
+
+    def metric(self, pref, pgen):
+        return cos_similarity(pref['frag'], pgen['frag'])
+
+
+class ScafMetric(Metric):
+    """
+    Metrics based on molecular scaffolds.
+    """
+    def precalc(self, mols):
+        return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)}
+
+    def metric(self, pref, pgen):
+        return cos_similarity(pref['scaf'], pgen['scaf'])
+
+
+def cos_similarity(ref_counts, gen_counts):
+    """
+    Computes cosine similarity between two molecular vectors.
+    
+    Args:
+        ref_counts (dict): Reference molecular vectors.
+        gen_counts (dict): Generated molecular vectors.
+    
+    Returns:
+        float: Cosine similarity between the two molecular vectors.
+    """
+    if len(ref_counts) == 0 or len(gen_counts) == 0:
+        return np.nan
+    keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys()))
+    ref_vec = np.array([ref_counts.get(k, 0) for k in keys])
+    gen_vec = np.array([gen_counts.get(k, 0) for k in keys])
+    return 1 - cos_distance(ref_vec, gen_vec)
\ No newline at end of file