|
a |
|
b/analysis/metrics.py |
|
|
1 |
import numpy as np |
|
|
2 |
from tqdm import tqdm |
|
|
3 |
from rdkit import Chem, DataStructs |
|
|
4 |
from rdkit.Chem import Descriptors, Crippen, Lipinski, QED |
|
|
5 |
from analysis.SA_Score.sascorer import calculateScore |
|
|
6 |
|
|
|
7 |
from analysis.molecule_builder import build_molecule |
|
|
8 |
from copy import deepcopy |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
class CategoricalDistribution: |
|
|
12 |
EPS = 1e-10 |
|
|
13 |
|
|
|
14 |
def __init__(self, histogram_dict, mapping): |
|
|
15 |
histogram = np.zeros(len(mapping)) |
|
|
16 |
for k, v in histogram_dict.items(): |
|
|
17 |
histogram[mapping[k]] = v |
|
|
18 |
|
|
|
19 |
# Normalize histogram |
|
|
20 |
self.p = histogram / histogram.sum() |
|
|
21 |
self.mapping = deepcopy(mapping) |
|
|
22 |
|
|
|
23 |
def kl_divergence(self, other_sample): |
|
|
24 |
sample_histogram = np.zeros(len(self.mapping)) |
|
|
25 |
for x in other_sample: |
|
|
26 |
# sample_histogram[self.mapping[x]] += 1 |
|
|
27 |
sample_histogram[x] += 1 |
|
|
28 |
|
|
|
29 |
# Normalize |
|
|
30 |
q = sample_histogram / sample_histogram.sum() |
|
|
31 |
|
|
|
32 |
return -np.sum(self.p * np.log(q / self.p + self.EPS)) |
|
|
33 |
|
|
|
34 |
|
|
|
35 |
def rdmol_to_smiles(rdmol): |
|
|
36 |
mol = Chem.Mol(rdmol) |
|
|
37 |
Chem.RemoveStereochemistry(mol) |
|
|
38 |
mol = Chem.RemoveHs(mol) |
|
|
39 |
return Chem.MolToSmiles(mol) |
|
|
40 |
|
|
|
41 |
|
|
|
42 |
class BasicMolecularMetrics(object): |
|
|
43 |
def __init__(self, dataset_info, dataset_smiles_list=None, |
|
|
44 |
connectivity_thresh=1.0): |
|
|
45 |
self.atom_decoder = dataset_info['atom_decoder'] |
|
|
46 |
if dataset_smiles_list is not None: |
|
|
47 |
dataset_smiles_list = set(dataset_smiles_list) |
|
|
48 |
self.dataset_smiles_list = dataset_smiles_list |
|
|
49 |
self.dataset_info = dataset_info |
|
|
50 |
self.connectivity_thresh = connectivity_thresh |
|
|
51 |
|
|
|
52 |
def compute_validity(self, generated): |
|
|
53 |
""" generated: list of couples (positions, atom_types)""" |
|
|
54 |
if len(generated) < 1: |
|
|
55 |
return [], 0.0 |
|
|
56 |
|
|
|
57 |
valid = [] |
|
|
58 |
for mol in generated: |
|
|
59 |
try: |
|
|
60 |
Chem.SanitizeMol(mol) |
|
|
61 |
except ValueError: |
|
|
62 |
continue |
|
|
63 |
|
|
|
64 |
valid.append(mol) |
|
|
65 |
|
|
|
66 |
return valid, len(valid) / len(generated) |
|
|
67 |
|
|
|
68 |
def compute_connectivity(self, valid): |
|
|
69 |
""" Consider molecule connected if its largest fragment contains at |
|
|
70 |
least x% of all atoms, where x is determined by |
|
|
71 |
self.connectivity_thresh (defaults to 100%). """ |
|
|
72 |
if len(valid) < 1: |
|
|
73 |
return [], 0.0 |
|
|
74 |
|
|
|
75 |
connected = [] |
|
|
76 |
connected_smiles = [] |
|
|
77 |
for mol in valid: |
|
|
78 |
mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True) |
|
|
79 |
largest_mol = \ |
|
|
80 |
max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms()) |
|
|
81 |
if largest_mol.GetNumAtoms() / mol.GetNumAtoms() >= self.connectivity_thresh: |
|
|
82 |
smiles = rdmol_to_smiles(largest_mol) |
|
|
83 |
if smiles is not None: |
|
|
84 |
connected_smiles.append(smiles) |
|
|
85 |
connected.append(largest_mol) |
|
|
86 |
|
|
|
87 |
return connected, len(connected_smiles) / len(valid), connected_smiles |
|
|
88 |
|
|
|
89 |
def compute_uniqueness(self, connected): |
|
|
90 |
""" valid: list of SMILES strings.""" |
|
|
91 |
if len(connected) < 1 or self.dataset_smiles_list is None: |
|
|
92 |
return [], 0.0 |
|
|
93 |
|
|
|
94 |
return list(set(connected)), len(set(connected)) / len(connected) |
|
|
95 |
|
|
|
96 |
def compute_novelty(self, unique): |
|
|
97 |
if len(unique) < 1: |
|
|
98 |
return [], 0.0 |
|
|
99 |
|
|
|
100 |
num_novel = 0 |
|
|
101 |
novel = [] |
|
|
102 |
for smiles in unique: |
|
|
103 |
if smiles not in self.dataset_smiles_list: |
|
|
104 |
novel.append(smiles) |
|
|
105 |
num_novel += 1 |
|
|
106 |
return novel, num_novel / len(unique) |
|
|
107 |
|
|
|
108 |
def evaluate_rdmols(self, rdmols): |
|
|
109 |
valid, validity = self.compute_validity(rdmols) |
|
|
110 |
print(f"Validity over {len(rdmols)} molecules: {validity * 100 :.2f}%") |
|
|
111 |
|
|
|
112 |
connected, connectivity, connected_smiles = \ |
|
|
113 |
self.compute_connectivity(valid) |
|
|
114 |
print(f"Connectivity over {len(valid)} valid molecules: " |
|
|
115 |
f"{connectivity * 100 :.2f}%") |
|
|
116 |
|
|
|
117 |
unique, uniqueness = self.compute_uniqueness(connected_smiles) |
|
|
118 |
print(f"Uniqueness over {len(connected)} connected molecules: " |
|
|
119 |
f"{uniqueness * 100 :.2f}%") |
|
|
120 |
|
|
|
121 |
_, novelty = self.compute_novelty(unique) |
|
|
122 |
print(f"Novelty over {len(unique)} unique connected molecules: " |
|
|
123 |
f"{novelty * 100 :.2f}%") |
|
|
124 |
|
|
|
125 |
return [validity, connectivity, uniqueness, novelty], [valid, connected] |
|
|
126 |
|
|
|
127 |
def evaluate(self, generated): |
|
|
128 |
""" generated: list of pairs (positions: n x 3, atom_types: n [int]) |
|
|
129 |
the positions and atom types should already be masked. """ |
|
|
130 |
|
|
|
131 |
rdmols = [build_molecule(*graph, self.dataset_info) |
|
|
132 |
for graph in generated] |
|
|
133 |
return self.evaluate_rdmols(rdmols) |
|
|
134 |
|
|
|
135 |
|
|
|
136 |
class MoleculeProperties: |
|
|
137 |
|
|
|
138 |
@staticmethod |
|
|
139 |
def calculate_qed(rdmol): |
|
|
140 |
return QED.qed(rdmol) |
|
|
141 |
|
|
|
142 |
@staticmethod |
|
|
143 |
def calculate_sa(rdmol): |
|
|
144 |
sa = calculateScore(rdmol) |
|
|
145 |
return round((10 - sa) / 9, 2) # from pocket2mol |
|
|
146 |
|
|
|
147 |
@staticmethod |
|
|
148 |
def calculate_logp(rdmol): |
|
|
149 |
return Crippen.MolLogP(rdmol) |
|
|
150 |
|
|
|
151 |
@staticmethod |
|
|
152 |
def calculate_lipinski(rdmol): |
|
|
153 |
rule_1 = Descriptors.ExactMolWt(rdmol) < 500 |
|
|
154 |
rule_2 = Lipinski.NumHDonors(rdmol) <= 5 |
|
|
155 |
rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10 |
|
|
156 |
rule_4 = (logp := Crippen.MolLogP(rdmol) >= -2) & (logp <= 5) |
|
|
157 |
rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10 |
|
|
158 |
return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]]) |
|
|
159 |
|
|
|
160 |
@classmethod |
|
|
161 |
def calculate_diversity(cls, pocket_mols): |
|
|
162 |
if len(pocket_mols) < 2: |
|
|
163 |
return 0.0 |
|
|
164 |
|
|
|
165 |
div = 0 |
|
|
166 |
total = 0 |
|
|
167 |
for i in range(len(pocket_mols)): |
|
|
168 |
for j in range(i + 1, len(pocket_mols)): |
|
|
169 |
div += 1 - cls.similarity(pocket_mols[i], pocket_mols[j]) |
|
|
170 |
total += 1 |
|
|
171 |
return div / total |
|
|
172 |
|
|
|
173 |
@staticmethod |
|
|
174 |
def similarity(mol_a, mol_b): |
|
|
175 |
# fp1 = AllChem.GetMorganFingerprintAsBitVect( |
|
|
176 |
# mol_a, 2, nBits=2048, useChirality=False) |
|
|
177 |
# fp2 = AllChem.GetMorganFingerprintAsBitVect( |
|
|
178 |
# mol_b, 2, nBits=2048, useChirality=False) |
|
|
179 |
fp1 = Chem.RDKFingerprint(mol_a) |
|
|
180 |
fp2 = Chem.RDKFingerprint(mol_b) |
|
|
181 |
return DataStructs.TanimotoSimilarity(fp1, fp2) |
|
|
182 |
|
|
|
183 |
def evaluate(self, pocket_rdmols): |
|
|
184 |
""" |
|
|
185 |
Run full evaluation |
|
|
186 |
Args: |
|
|
187 |
pocket_rdmols: list of lists, the inner list contains all RDKit |
|
|
188 |
molecules generated for a pocket |
|
|
189 |
Returns: |
|
|
190 |
QED, SA, LogP, Lipinski (per molecule), and Diversity (per pocket) |
|
|
191 |
""" |
|
|
192 |
|
|
|
193 |
for pocket in pocket_rdmols: |
|
|
194 |
for mol in pocket: |
|
|
195 |
Chem.SanitizeMol(mol) |
|
|
196 |
assert mol is not None, "only evaluate valid molecules" |
|
|
197 |
|
|
|
198 |
all_qed = [] |
|
|
199 |
all_sa = [] |
|
|
200 |
all_logp = [] |
|
|
201 |
all_lipinski = [] |
|
|
202 |
per_pocket_diversity = [] |
|
|
203 |
for pocket in tqdm(pocket_rdmols): |
|
|
204 |
all_qed.append([self.calculate_qed(mol) for mol in pocket]) |
|
|
205 |
all_sa.append([self.calculate_sa(mol) for mol in pocket]) |
|
|
206 |
all_logp.append([self.calculate_logp(mol) for mol in pocket]) |
|
|
207 |
all_lipinski.append([self.calculate_lipinski(mol) for mol in pocket]) |
|
|
208 |
per_pocket_diversity.append(self.calculate_diversity(pocket)) |
|
|
209 |
|
|
|
210 |
print(f"{sum([len(p) for p in pocket_rdmols])} molecules from " |
|
|
211 |
f"{len(pocket_rdmols)} pockets evaluated.") |
|
|
212 |
|
|
|
213 |
qed_flattened = [x for px in all_qed for x in px] |
|
|
214 |
print(f"QED: {np.mean(qed_flattened):.3f} \pm {np.std(qed_flattened):.2f}") |
|
|
215 |
|
|
|
216 |
sa_flattened = [x for px in all_sa for x in px] |
|
|
217 |
print(f"SA: {np.mean(sa_flattened):.3f} \pm {np.std(sa_flattened):.2f}") |
|
|
218 |
|
|
|
219 |
logp_flattened = [x for px in all_logp for x in px] |
|
|
220 |
print(f"LogP: {np.mean(logp_flattened):.3f} \pm {np.std(logp_flattened):.2f}") |
|
|
221 |
|
|
|
222 |
lipinski_flattened = [x for px in all_lipinski for x in px] |
|
|
223 |
print(f"Lipinski: {np.mean(lipinski_flattened):.3f} \pm {np.std(lipinski_flattened):.2f}") |
|
|
224 |
|
|
|
225 |
print(f"Diversity: {np.mean(per_pocket_diversity):.3f} \pm {np.std(per_pocket_diversity):.2f}") |
|
|
226 |
|
|
|
227 |
return all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity |
|
|
228 |
|
|
|
229 |
def evaluate_mean(self, rdmols): |
|
|
230 |
""" |
|
|
231 |
Run full evaluation and return mean of each property |
|
|
232 |
Args: |
|
|
233 |
rdmols: list of RDKit molecules |
|
|
234 |
Returns: |
|
|
235 |
QED, SA, LogP, Lipinski, and Diversity |
|
|
236 |
""" |
|
|
237 |
|
|
|
238 |
if len(rdmols) < 1: |
|
|
239 |
return 0.0, 0.0, 0.0, 0.0, 0.0 |
|
|
240 |
|
|
|
241 |
for mol in rdmols: |
|
|
242 |
Chem.SanitizeMol(mol) |
|
|
243 |
assert mol is not None, "only evaluate valid molecules" |
|
|
244 |
|
|
|
245 |
qed = np.mean([self.calculate_qed(mol) for mol in rdmols]) |
|
|
246 |
sa = np.mean([self.calculate_sa(mol) for mol in rdmols]) |
|
|
247 |
logp = np.mean([self.calculate_logp(mol) for mol in rdmols]) |
|
|
248 |
lipinski = np.mean([self.calculate_lipinski(mol) for mol in rdmols]) |
|
|
249 |
diversity = self.calculate_diversity(rdmols) |
|
|
250 |
|
|
|
251 |
return qed, sa, logp, lipinski, diversity |