"""Evaluation metrics for generated molecules.

Provides a small categorical-distribution helper (KL divergence), basic
set-level metrics (validity, connectivity, uniqueness, novelty) and
per-molecule chemical properties (QED, SA, LogP, Lipinski, diversity)
computed with RDKit.
"""
from copy import deepcopy
from itertools import combinations

import numpy as np
from tqdm import tqdm
from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors, Crippen, Lipinski, QED
from analysis.SA_Score.sascorer import calculateScore

from analysis.molecule_builder import build_molecule


class CategoricalDistribution:
    """Reference categorical distribution built from a histogram."""

    # Added inside the logarithm so that empty sample bins do not give log(0).
    EPS = 1e-10

    def __init__(self, histogram_dict, mapping):
        """
        Args:
            histogram_dict: dict mapping a category to its count/weight
            mapping: dict mapping a category to its bin index; its length
                determines the number of bins
        """
        histogram = np.zeros(len(mapping))
        for category, count in histogram_dict.items():
            histogram[mapping[category]] = count

        # Normalize histogram
        self.p = histogram / histogram.sum()
        self.mapping = deepcopy(mapping)

    def kl_divergence(self, other_sample):
        """ Compute KL(p || q), where p is this distribution and q is the
        empirical distribution of other_sample.
        Args:
            other_sample: iterable of bin indices. The values are used as
                indices directly; they are NOT passed through self.mapping.
        Returns:
            The divergence as a float. Bins with p == 0 contribute nothing
            (0 * log(0) := 0) instead of turning the result into NaN.
            An empty sample still yields NaN because q is undefined.
        """
        sample_histogram = np.zeros(len(self.mapping))
        for x in other_sample:
            sample_histogram[x] += 1

        # Normalize
        q = sample_histogram / sample_histogram.sum()

        # Restrict the sum to the support of p: outside of it the original
        # expression evaluated 0 * log(q / 0 + EPS), which is NaN.
        support = self.p > 0
        p = self.p[support]
        return -np.sum(p * np.log(q[support] / p + self.EPS))


def rdmol_to_smiles(rdmol):
    """ Return the SMILES string of a copy of rdmol after removing
    stereochemistry and hydrogens; the input molecule is not modified. """
    mol = Chem.Mol(rdmol)
    Chem.RemoveStereochemistry(mol)
    mol = Chem.RemoveHs(mol)
    return Chem.MolToSmiles(mol)


class BasicMolecularMetrics:
    """Validity, connectivity, uniqueness and novelty of generated molecules."""

    def __init__(self, dataset_info, dataset_smiles_list=None,
                 connectivity_thresh=1.0):
        """
        Args:
            dataset_info: dict with at least the key 'atom_decoder'; it is
                also forwarded to build_molecule
            dataset_smiles_list: optional iterable of reference SMILES used
                for the novelty metric (stored as a set for fast lookup)
            connectivity_thresh: minimum fraction of atoms the largest
                fragment must contain for a molecule to count as connected
        """
        self.atom_decoder = dataset_info['atom_decoder']
        if dataset_smiles_list is not None:
            dataset_smiles_list = set(dataset_smiles_list)
        self.dataset_smiles_list = dataset_smiles_list
        self.dataset_info = dataset_info
        self.connectivity_thresh = connectivity_thresh

    def compute_validity(self, generated):
        """ generated: list of RDKit molecules.
        Returns the molecules that pass Chem.SanitizeMol (which is applied
        to them in place) and the fraction of such molecules. """
        if len(generated) < 1:
            return [], 0.0

        valid = []
        for mol in generated:
            try:
                Chem.SanitizeMol(mol)
            except ValueError:
                continue

            valid.append(mol)

        return valid, len(valid) / len(generated)

    def compute_connectivity(self, valid):
        """ Consider molecule connected if its largest fragment contains at
        least x% of all atoms, where x is determined by
        self.connectivity_thresh (defaults to 100%).
        Returns:
            (largest fragments of the connected molecules,
             fraction of connected molecules,
             SMILES strings of those largest fragments)
        """
        if len(valid) < 1:
            # Same arity as the regular return value so that callers can
            # always unpack three values.
            return [], 0.0, []

        connected = []
        connected_smiles = []
        for mol in valid:
            num_atoms = mol.GetNumAtoms()
            if num_atoms == 0:
                # An empty molecule has no fragments; skip it rather than
                # dividing by zero below.
                continue

            mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True)
            largest_mol = \
                max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
            if largest_mol.GetNumAtoms() / num_atoms >= self.connectivity_thresh:
                smiles = rdmol_to_smiles(largest_mol)
                if smiles is not None:
                    connected_smiles.append(smiles)
                    connected.append(largest_mol)

        return connected, len(connected_smiles) / len(valid), connected_smiles

    def compute_uniqueness(self, connected):
        """ connected: list of SMILES strings.
        Returns the distinct SMILES and their fraction among the input.
        This does not depend on the reference dataset. """
        if len(connected) < 1:
            return [], 0.0

        unique = set(connected)
        return list(unique), len(unique) / len(connected)

    def compute_novelty(self, unique):
        """ unique: list of SMILES strings.
        Returns the SMILES that are not in the reference dataset and their
        fraction. Without a reference dataset nothing can be assessed and
        ([], 0.0) is returned. """
        if len(unique) < 1 or self.dataset_smiles_list is None:
            return [], 0.0

        novel = [smiles for smiles in unique
                 if smiles not in self.dataset_smiles_list]
        return novel, len(novel) / len(unique)

    def evaluate_rdmols(self, rdmols):
        """ Compute and print all metrics for a list of RDKit molecules.
        Returns:
            [validity, connectivity, uniqueness, novelty], [valid, connected]
        """
        valid, validity = self.compute_validity(rdmols)
        print(f"Validity over {len(rdmols)} molecules: {validity * 100 :.2f}%")

        connected, connectivity, connected_smiles = \
            self.compute_connectivity(valid)
        print(f"Connectivity over {len(valid)} valid molecules: "
              f"{connectivity * 100 :.2f}%")

        unique, uniqueness = self.compute_uniqueness(connected_smiles)
        print(f"Uniqueness over {len(connected)} connected molecules: "
              f"{uniqueness * 100 :.2f}%")

        _, novelty = self.compute_novelty(unique)
        print(f"Novelty over {len(unique)} unique connected molecules: "
              f"{novelty * 100 :.2f}%")

        return [validity, connectivity, uniqueness, novelty], [valid, connected]

    def evaluate(self, generated):
        """ generated: list of pairs (positions: n x 3, atom_types: n [int])
        the positions and atom types should already be masked.
        Each pair is converted with build_molecule and the result is passed
        to evaluate_rdmols. """

        rdmols = [build_molecule(*graph, self.dataset_info)
                  for graph in generated]
        return self.evaluate_rdmols(rdmols)


class MoleculeProperties:
    """Chemical property scores for RDKit molecules."""

    @staticmethod
    def calculate_qed(rdmol):
        """ Return the QED score of rdmol (rdkit.Chem.QED.qed). """
        return QED.qed(rdmol)

    @staticmethod
    def calculate_sa(rdmol):
        """ Return the SA score of rdmol rescaled as (10 - sa) / 9 and
        rounded to two decimals. """
        sa = calculateScore(rdmol)
        return round((10 - sa) / 9, 2)  # from pocket2mol

    @staticmethod
    def calculate_logp(rdmol):
        """ Return the Crippen LogP estimate of rdmol. """
        return Crippen.MolLogP(rdmol)

    @staticmethod
    def calculate_lipinski(rdmol):
        """ Return how many of the following five rules rdmol satisfies:
        exact molecular weight < 500, H-bond donors <= 5,
        H-bond acceptors <= 10, -2 <= LogP <= 5, rotatable bonds <= 10. """
        rule_1 = Descriptors.ExactMolWt(rdmol) < 500
        rule_2 = Lipinski.NumHDonors(rdmol) <= 5
        rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
        # Bind the LogP value itself before comparing. Binding it with
        # ":=" inside the comparison captured the boolean result, so the
        # upper bound was never actually checked.
        logp = Crippen.MolLogP(rdmol)
        rule_4 = -2 <= logp <= 5
        rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
        return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])

    @classmethod
    def calculate_diversity(cls, pocket_mols):
        """ Return the mean pairwise dissimilarity (1 - similarity) over all
        unordered pairs of molecules; 0.0 if there are fewer than two. """
        if len(pocket_mols) < 2:
            return 0.0

        dissimilarities = [1 - cls.similarity(mol_a, mol_b)
                           for mol_a, mol_b in combinations(pocket_mols, 2)]
        return sum(dissimilarities) / len(dissimilarities)

    @staticmethod
    def similarity(mol_a, mol_b):
        """ Return the Tanimoto similarity of the RDKit fingerprints of the
        two molecules. """
        # Alternative kept for reference: Morgan fingerprints via
        # AllChem.GetMorganFingerprintAsBitVect(
        #     mol, 2, nBits=2048, useChirality=False)
        fp1 = Chem.RDKFingerprint(mol_a)
        fp2 = Chem.RDKFingerprint(mol_b)
        return DataStructs.TanimotoSimilarity(fp1, fp2)

    @staticmethod
    def _sanitize_all(rdmols):
        """ Check that every molecule is present and sanitize it in place. """
        for mol in rdmols:
            # The check has to come first: sanitizing None fails before an
            # assertion placed afterwards could ever report the problem.
            assert mol is not None, "only evaluate valid molecules"
            Chem.SanitizeMol(mol)

    @staticmethod
    def _print_summary(name, values):
        """ Print 'name: mean \\pm std' for a flat list of values. """
        print(f"{name}: {np.mean(values):.3f} \\pm {np.std(values):.2f}")

    def evaluate(self, pocket_rdmols):
        """
        Run full evaluation
        Args:
            pocket_rdmols: list of lists, the inner list contains all RDKit
                molecules generated for a pocket
        Returns:
            QED, SA, LogP, Lipinski (per molecule), and Diversity (per pocket)
        """

        for pocket in pocket_rdmols:
            self._sanitize_all(pocket)

        all_qed = []
        all_sa = []
        all_logp = []
        all_lipinski = []
        per_pocket_diversity = []
        for pocket in tqdm(pocket_rdmols):
            all_qed.append([self.calculate_qed(mol) for mol in pocket])
            all_sa.append([self.calculate_sa(mol) for mol in pocket])
            all_logp.append([self.calculate_logp(mol) for mol in pocket])
            all_lipinski.append([self.calculate_lipinski(mol) for mol in pocket])
            per_pocket_diversity.append(self.calculate_diversity(pocket))

        print(f"{sum(len(p) for p in pocket_rdmols)} molecules from "
              f"{len(pocket_rdmols)} pockets evaluated.")

        # Per-molecule properties are summarized over all pockets at once.
        per_molecule = (
            ("QED", all_qed),
            ("SA", all_sa),
            ("LogP", all_logp),
            ("Lipinski", all_lipinski),
        )
        for name, per_pocket_values in per_molecule:
            self._print_summary(
                name, [x for px in per_pocket_values for x in px])

        self._print_summary("Diversity", per_pocket_diversity)

        return all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity

    def evaluate_mean(self, rdmols):
        """
        Run full evaluation and return mean of each property
        Args:
            rdmols: list of RDKit molecules
        Returns:
            QED, SA, LogP, Lipinski, and Diversity
        """

        if len(rdmols) < 1:
            return 0.0, 0.0, 0.0, 0.0, 0.0

        self._sanitize_all(rdmols)

        qed = np.mean([self.calculate_qed(mol) for mol in rdmols])
        sa = np.mean([self.calculate_sa(mol) for mol in rdmols])
        logp = np.mean([self.calculate_logp(mol) for mol in rdmols])
        lipinski = np.mean([self.calculate_lipinski(mol) for mol in rdmols])
        diversity = self.calculate_diversity(rdmols)

        return qed, sa, logp, lipinski, diversity