import numpy as np
from tqdm import tqdm
from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors, Crippen, Lipinski, QED
from analysis.SA_Score.sascorer import calculateScore
from analysis.molecule_builder import build_molecule
from copy import deepcopy
class CategoricalDistribution:
    """Reference categorical distribution built from a histogram.

    Used to compare an empirical sample of category indices against the
    reference distribution via the KL divergence.
    """

    # Added inside the log so that categories that are missing from the
    # sample (q == 0) give a large finite penalty instead of log(0) = -inf.
    EPS = 1e-10

    def __init__(self, histogram_dict, mapping):
        """
        Args:
            histogram_dict: maps category keys to (unnormalized) counts.
                Keys missing from it get probability 0.
            mapping: maps category keys to integer positions in a histogram
                of length ``len(mapping)``; presumably a bijection onto
                ``range(len(mapping))`` — TODO confirm with callers.
        """
        histogram = np.zeros(len(mapping))
        for k, v in histogram_dict.items():
            histogram[mapping[k]] = v

        # Normalize histogram
        self.p = histogram / histogram.sum()

        self.mapping = deepcopy(mapping)

    def kl_divergence(self, other_sample):
        """Return KL(p || q), with p the reference and q the sample histogram.

        Args:
            other_sample: iterable of integer category indices. Elements are
                used directly as histogram indices (they are NOT passed
                through ``self.mapping``).

        Returns:
            ``sum_i p_i * log(p_i / q_i)`` (computed as
            ``-sum p * log(q / p + EPS)``). Categories with ``p_i == 0``
            contribute nothing. An empty sample yields nan (0 / 0
            normalization).
        """
        sample_histogram = np.zeros(len(self.mapping))
        for x in other_sample:
            sample_histogram[x] += 1

        # Normalize
        q = sample_histogram / sample_histogram.sum()

        # By convention 0 * log(...) = 0 in KL(p || q). Evaluating those terms
        # literally would compute 0 * log(q / 0) = 0 * log(inf or nan) = nan
        # and poison the whole sum, so restrict to the support of p.
        support = self.p != 0
        p = self.p[support]
        return -np.sum(p * np.log(q[support] / p + self.EPS))
def rdmol_to_smiles(rdmol):
    """Return the SMILES string of ``rdmol`` with stereochemistry removed.

    The input molecule is left untouched: stereochemistry is stripped from a
    copy, and hydrogens are then removed with ``Chem.RemoveHs`` before the
    string is written with ``Chem.MolToSmiles``.
    """
    # RemoveStereochemistry edits its argument in place, hence the copy.
    working_copy = Chem.Mol(rdmol)
    Chem.RemoveStereochemistry(working_copy)
    without_hs = Chem.RemoveHs(working_copy)
    return Chem.MolToSmiles(without_hs)
class BasicMolecularMetrics(object):
    """Validity, connectivity, uniqueness and novelty of generated molecules."""

    def __init__(self, dataset_info, dataset_smiles_list=None,
                 connectivity_thresh=1.0):
        """
        Args:
            dataset_info: dict that must contain ``'atom_decoder'``; it is
                also passed unchanged to ``build_molecule`` in ``evaluate``.
            dataset_smiles_list: optional iterable of reference SMILES used
                for novelty. Stored as a set for O(1) membership tests. When
                None, novelty is reported as 0.0.
            connectivity_thresh: minimum fraction of a molecule's atoms that
                its largest fragment must contain to count as connected.
        """
        self.atom_decoder = dataset_info['atom_decoder']
        if dataset_smiles_list is not None:
            dataset_smiles_list = set(dataset_smiles_list)
        self.dataset_smiles_list = dataset_smiles_list
        self.dataset_info = dataset_info
        self.connectivity_thresh = connectivity_thresh

    def compute_validity(self, generated):
        """Keep the molecules that pass RDKit sanitization.

        Args:
            generated: list of RDKit molecules; each is sanitized in place.

        Returns:
            (list of valid molecules, fraction valid); ``([], 0.0)`` for
            empty input.
        """
        if len(generated) < 1:
            return [], 0.0

        valid = []
        for mol in generated:
            try:
                Chem.SanitizeMol(mol)
            except ValueError:
                # Presumably RDKit's sanitization errors subclass ValueError
                # — confirm for the RDKit version in use.
                continue
            valid.append(mol)

        return valid, len(valid) / len(generated)

    def compute_connectivity(self, valid):
        """ Consider molecule connected if its largest fragment contains at
        least x% of all atoms, where x is determined by
        self.connectivity_thresh (defaults to 100%).

        Args:
            valid: list of RDKit molecules.

        Returns:
            (largest fragments of connected molecules, fraction connected,
            SMILES of those fragments). Always a 3-tuple, including
            ``([], 0.0, [])`` for empty input.
        """
        if len(valid) < 1:
            # Same arity as the non-empty path: evaluate_rdmols unpacks three.
            return [], 0.0, []

        connected = []
        connected_smiles = []
        for mol in valid:
            num_atoms = mol.GetNumAtoms()
            if num_atoms == 0:
                # Nothing to be connected; avoid dividing by zero below.
                continue
            mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True)
            largest_mol = \
                max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
            if largest_mol.GetNumAtoms() / num_atoms >= self.connectivity_thresh:
                smiles = rdmol_to_smiles(largest_mol)
                if smiles is not None:
                    connected_smiles.append(smiles)
                    connected.append(largest_mol)

        return connected, len(connected_smiles) / len(valid), connected_smiles

    def compute_uniqueness(self, connected):
        """Deduplicate SMILES strings.

        Args:
            connected: list of SMILES strings.

        Returns:
            (unique SMILES in first-seen order, fraction unique);
            ``([], 0.0)`` for empty input. Does not need a reference dataset.
        """
        if len(connected) < 1:
            return [], 0.0

        # dict.fromkeys keeps first-seen order; a set's iteration order for
        # strings changes between interpreter runs (hash randomization).
        unique = list(dict.fromkeys(connected))
        return unique, len(unique) / len(connected)

    def compute_novelty(self, unique):
        """Find SMILES that do not appear in the reference dataset.

        Args:
            unique: list of SMILES strings.

        Returns:
            (novel SMILES, fraction novel); ``([], 0.0)`` for empty input or
            when no reference dataset was given.
        """
        if len(unique) < 1 or self.dataset_smiles_list is None:
            return [], 0.0

        novel = [smiles for smiles in unique
                 if smiles not in self.dataset_smiles_list]
        return novel, len(novel) / len(unique)

    def evaluate_rdmols(self, rdmols):
        """Compute and print validity, connectivity, uniqueness and novelty.

        Args:
            rdmols: list of RDKit molecules (sanitized in place).

        Returns:
            ([validity, connectivity, uniqueness, novelty], [valid, connected])
        """
        valid, validity = self.compute_validity(rdmols)
        print(f"Validity over {len(rdmols)} molecules: {validity * 100 :.2f}%")

        connected, connectivity, connected_smiles = \
            self.compute_connectivity(valid)
        print(f"Connectivity over {len(valid)} valid molecules: "
              f"{connectivity * 100 :.2f}%")

        unique, uniqueness = self.compute_uniqueness(connected_smiles)
        print(f"Uniqueness over {len(connected)} connected molecules: "
              f"{uniqueness * 100 :.2f}%")

        _, novelty = self.compute_novelty(unique)
        print(f"Novelty over {len(unique)} unique connected molecules: "
              f"{novelty * 100 :.2f}%")

        return [validity, connectivity, uniqueness, novelty], [valid, connected]

    def evaluate(self, generated):
        """ generated: list of pairs (positions: n x 3, atom_types: n [int])
        the positions and atom types should already be masked.

        Each pair is unpacked as positional arguments to ``build_molecule``
        (followed by ``self.dataset_info``), and the resulting molecules are
        scored with ``evaluate_rdmols``.
        """
        rdmols = [build_molecule(*graph, self.dataset_info)
                  for graph in generated]
        return self.evaluate_rdmols(rdmols)
class MoleculeProperties:
    """Per-molecule drug-likeness properties and per-pocket diversity."""

    @staticmethod
    def calculate_qed(rdmol):
        """Return RDKit's QED score of ``rdmol``."""
        return QED.qed(rdmol)

    @staticmethod
    def calculate_sa(rdmol):
        """Return the SA score rescaled as ``(10 - sa) / 9``, rounded to 2 dp."""
        sa = calculateScore(rdmol)
        return round((10 - sa) / 9, 2)  # from pocket2mol

    @staticmethod
    def calculate_logp(rdmol):
        """Return the Crippen logP of ``rdmol``."""
        return Crippen.MolLogP(rdmol)

    @staticmethod
    def calculate_lipinski(rdmol):
        """Count how many of five drug-likeness rules ``rdmol`` satisfies.

        Rules: exact mol. weight < 500, H-bond donors <= 5, H-bond acceptors
        <= 10, -2 <= logP <= 5, rotatable bonds <= 10.

        Returns:
            Number of satisfied rules (0-5).
        """
        # Compute logP once as a value. A walrus inside the comparison,
        # `(logp := MolLogP(m) >= -2)`, would bind the *boolean*, and then the
        # upper bound check `logp <= 5` would always be True.
        logp = Crippen.MolLogP(rdmol)
        rules = (
            Descriptors.ExactMolWt(rdmol) < 500,
            Lipinski.NumHDonors(rdmol) <= 5,
            Lipinski.NumHAcceptors(rdmol) <= 10,
            -2 <= logp <= 5,
            Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10,
        )
        return np.sum([int(rule) for rule in rules])

    @classmethod
    def calculate_diversity(cls, pocket_mols):
        """Mean of ``1 - cls.similarity(a, b)`` over all unordered pairs.

        Returns 0.0 when there are fewer than two molecules. ``cls.similarity``
        is looked up on the class so subclasses can override it.
        """
        n = len(pocket_mols)
        if n < 2:
            return 0.0

        div = sum(1 - cls.similarity(mol_a, mol_b)
                  for i, mol_a in enumerate(pocket_mols)
                  for mol_b in pocket_mols[i + 1:])
        return div / (n * (n - 1) // 2)

    @staticmethod
    def similarity(mol_a, mol_b):
        """Tanimoto similarity of the RDKit fingerprints of two molecules."""
        fp1 = Chem.RDKFingerprint(mol_a)
        fp2 = Chem.RDKFingerprint(mol_b)
        return DataStructs.TanimotoSimilarity(fp1, fp2)

    @staticmethod
    def _sanitize_valid(mols):
        """Sanitize each molecule in place, rejecting ``None`` entries first."""
        for mol in mols:
            # Check before sanitizing; SanitizeMol(None) fails with a less
            # helpful error before any assertion could run.
            if mol is None:
                raise ValueError("only evaluate valid molecules")
            Chem.SanitizeMol(mol)

    @staticmethod
    def _print_stats(label, values):
        """Print ``label: mean \\pm std`` for a flat list of values."""
        print(f"{label}: {np.mean(values):.3f} \\pm {np.std(values):.2f}")

    def evaluate(self, pocket_rdmols):
        """
        Run full evaluation

        Args:
            pocket_rdmols: list of lists, the inner list contains all RDKit
                molecules generated for a pocket (sanitized in place)

        Returns:
            QED, SA, LogP, Lipinski (per molecule, nested per pocket), and
            Diversity (per pocket)

        Raises:
            ValueError: if any molecule is None.
        """
        for pocket in pocket_rdmols:
            self._sanitize_valid(pocket)

        all_qed = []
        all_sa = []
        all_logp = []
        all_lipinski = []
        per_pocket_diversity = []
        for pocket in tqdm(pocket_rdmols):
            all_qed.append([self.calculate_qed(mol) for mol in pocket])
            all_sa.append([self.calculate_sa(mol) for mol in pocket])
            all_logp.append([self.calculate_logp(mol) for mol in pocket])
            all_lipinski.append([self.calculate_lipinski(mol) for mol in pocket])
            per_pocket_diversity.append(self.calculate_diversity(pocket))

        print(f"{sum([len(p) for p in pocket_rdmols])} molecules from "
              f"{len(pocket_rdmols)} pockets evaluated.")

        per_molecule = (("QED", all_qed), ("SA", all_sa),
                        ("LogP", all_logp), ("Lipinski", all_lipinski))
        for label, nested in per_molecule:
            # Statistics are over all molecules, pooled across pockets.
            self._print_stats(label, [x for px in nested for x in px])
        self._print_stats("Diversity", per_pocket_diversity)

        return all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity

    def evaluate_mean(self, rdmols):
        """
        Run full evaluation and return mean of each property

        Args:
            rdmols: list of RDKit molecules (sanitized in place)

        Returns:
            QED, SA, LogP, Lipinski, and Diversity; all 0.0 for empty input

        Raises:
            ValueError: if any molecule is None.
        """
        if len(rdmols) < 1:
            return 0.0, 0.0, 0.0, 0.0, 0.0

        self._sanitize_valid(rdmols)

        qed = np.mean([self.calculate_qed(mol) for mol in rdmols])
        sa = np.mean([self.calculate_sa(mol) for mol in rdmols])
        logp = np.mean([self.calculate_logp(mol) for mol in rdmols])
        lipinski = np.mean([self.calculate_lipinski(mol) for mol in rdmols])
        diversity = self.calculate_diversity(rdmols)

        return qed, sa, logp, lipinski, diversity