Diff of /analysis/metrics.py [000000] .. [607087]

Switch to unified view

a b/analysis/metrics.py
1
import numpy as np
2
from tqdm import tqdm
3
from rdkit import Chem, DataStructs
4
from rdkit.Chem import Descriptors, Crippen, Lipinski, QED
5
from analysis.SA_Score.sascorer import calculateScore
6
7
from analysis.molecule_builder import build_molecule
8
from copy import deepcopy
9
10
11
class CategoricalDistribution:
12
    EPS = 1e-10
13
14
    def __init__(self, histogram_dict, mapping):
15
        histogram = np.zeros(len(mapping))
16
        for k, v in histogram_dict.items():
17
            histogram[mapping[k]] = v
18
19
        # Normalize histogram
20
        self.p = histogram / histogram.sum()
21
        self.mapping = deepcopy(mapping)
22
23
    def kl_divergence(self, other_sample):
24
        sample_histogram = np.zeros(len(self.mapping))
25
        for x in other_sample:
26
            # sample_histogram[self.mapping[x]] += 1
27
            sample_histogram[x] += 1
28
29
        # Normalize
30
        q = sample_histogram / sample_histogram.sum()
31
32
        return -np.sum(self.p * np.log(q / self.p + self.EPS))
33
34
35
def rdmol_to_smiles(rdmol):
36
    mol = Chem.Mol(rdmol)
37
    Chem.RemoveStereochemistry(mol)
38
    mol = Chem.RemoveHs(mol)
39
    return Chem.MolToSmiles(mol)
40
41
42
class BasicMolecularMetrics(object):
43
    def __init__(self, dataset_info, dataset_smiles_list=None,
44
                 connectivity_thresh=1.0):
45
        self.atom_decoder = dataset_info['atom_decoder']
46
        if dataset_smiles_list is not None:
47
            dataset_smiles_list = set(dataset_smiles_list)
48
        self.dataset_smiles_list = dataset_smiles_list
49
        self.dataset_info = dataset_info
50
        self.connectivity_thresh = connectivity_thresh
51
52
    def compute_validity(self, generated):
53
        """ generated: list of couples (positions, atom_types)"""
54
        if len(generated) < 1:
55
            return [], 0.0
56
57
        valid = []
58
        for mol in generated:
59
            try:
60
                Chem.SanitizeMol(mol)
61
            except ValueError:
62
                continue
63
64
            valid.append(mol)
65
66
        return valid, len(valid) / len(generated)
67
68
    def compute_connectivity(self, valid):
69
        """ Consider molecule connected if its largest fragment contains at
70
        least x% of all atoms, where x is determined by
71
        self.connectivity_thresh (defaults to 100%). """
72
        if len(valid) < 1:
73
            return [], 0.0
74
75
        connected = []
76
        connected_smiles = []
77
        for mol in valid:
78
            mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True)
79
            largest_mol = \
80
                max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
81
            if largest_mol.GetNumAtoms() / mol.GetNumAtoms() >= self.connectivity_thresh:
82
                smiles = rdmol_to_smiles(largest_mol)
83
                if smiles is not None:
84
                    connected_smiles.append(smiles)
85
                    connected.append(largest_mol)
86
87
        return connected, len(connected_smiles) / len(valid), connected_smiles
88
89
    def compute_uniqueness(self, connected):
90
        """ valid: list of SMILES strings."""
91
        if len(connected) < 1 or self.dataset_smiles_list is None:
92
            return [], 0.0
93
94
        return list(set(connected)), len(set(connected)) / len(connected)
95
96
    def compute_novelty(self, unique):
97
        if len(unique) < 1:
98
            return [], 0.0
99
100
        num_novel = 0
101
        novel = []
102
        for smiles in unique:
103
            if smiles not in self.dataset_smiles_list:
104
                novel.append(smiles)
105
                num_novel += 1
106
        return novel, num_novel / len(unique)
107
108
    def evaluate_rdmols(self, rdmols):
109
        valid, validity = self.compute_validity(rdmols)
110
        print(f"Validity over {len(rdmols)} molecules: {validity * 100 :.2f}%")
111
112
        connected, connectivity, connected_smiles = \
113
            self.compute_connectivity(valid)
114
        print(f"Connectivity over {len(valid)} valid molecules: "
115
              f"{connectivity * 100 :.2f}%")
116
117
        unique, uniqueness = self.compute_uniqueness(connected_smiles)
118
        print(f"Uniqueness over {len(connected)} connected molecules: "
119
              f"{uniqueness * 100 :.2f}%")
120
121
        _, novelty = self.compute_novelty(unique)
122
        print(f"Novelty over {len(unique)} unique connected molecules: "
123
              f"{novelty * 100 :.2f}%")
124
125
        return [validity, connectivity, uniqueness, novelty], [valid, connected]
126
127
    def evaluate(self, generated):
128
        """ generated: list of pairs (positions: n x 3, atom_types: n [int])
129
            the positions and atom types should already be masked. """
130
131
        rdmols = [build_molecule(*graph, self.dataset_info)
132
                  for graph in generated]
133
        return self.evaluate_rdmols(rdmols)
134
135
136
class MoleculeProperties:
137
138
    @staticmethod
139
    def calculate_qed(rdmol):
140
        return QED.qed(rdmol)
141
142
    @staticmethod
143
    def calculate_sa(rdmol):
144
        sa = calculateScore(rdmol)
145
        return round((10 - sa) / 9, 2)  # from pocket2mol
146
147
    @staticmethod
148
    def calculate_logp(rdmol):
149
        return Crippen.MolLogP(rdmol)
150
151
    @staticmethod
152
    def calculate_lipinski(rdmol):
153
        rule_1 = Descriptors.ExactMolWt(rdmol) < 500
154
        rule_2 = Lipinski.NumHDonors(rdmol) <= 5
155
        rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
156
        rule_4 = (logp := Crippen.MolLogP(rdmol) >= -2) & (logp <= 5)
157
        rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
158
        return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
159
160
    @classmethod
161
    def calculate_diversity(cls, pocket_mols):
162
        if len(pocket_mols) < 2:
163
            return 0.0
164
165
        div = 0
166
        total = 0
167
        for i in range(len(pocket_mols)):
168
            for j in range(i + 1, len(pocket_mols)):
169
                div += 1 - cls.similarity(pocket_mols[i], pocket_mols[j])
170
                total += 1
171
        return div / total
172
173
    @staticmethod
174
    def similarity(mol_a, mol_b):
175
        # fp1 = AllChem.GetMorganFingerprintAsBitVect(
176
        #     mol_a, 2, nBits=2048, useChirality=False)
177
        # fp2 = AllChem.GetMorganFingerprintAsBitVect(
178
        #     mol_b, 2, nBits=2048, useChirality=False)
179
        fp1 = Chem.RDKFingerprint(mol_a)
180
        fp2 = Chem.RDKFingerprint(mol_b)
181
        return DataStructs.TanimotoSimilarity(fp1, fp2)
182
183
    def evaluate(self, pocket_rdmols):
184
        """
185
        Run full evaluation
186
        Args:
187
            pocket_rdmols: list of lists, the inner list contains all RDKit
188
                molecules generated for a pocket
189
        Returns:
190
            QED, SA, LogP, Lipinski (per molecule), and Diversity (per pocket)
191
        """
192
193
        for pocket in pocket_rdmols:
194
            for mol in pocket:
195
                Chem.SanitizeMol(mol)
196
                assert mol is not None, "only evaluate valid molecules"
197
198
        all_qed = []
199
        all_sa = []
200
        all_logp = []
201
        all_lipinski = []
202
        per_pocket_diversity = []
203
        for pocket in tqdm(pocket_rdmols):
204
            all_qed.append([self.calculate_qed(mol) for mol in pocket])
205
            all_sa.append([self.calculate_sa(mol) for mol in pocket])
206
            all_logp.append([self.calculate_logp(mol) for mol in pocket])
207
            all_lipinski.append([self.calculate_lipinski(mol) for mol in pocket])
208
            per_pocket_diversity.append(self.calculate_diversity(pocket))
209
210
        print(f"{sum([len(p) for p in pocket_rdmols])} molecules from "
211
              f"{len(pocket_rdmols)} pockets evaluated.")
212
213
        qed_flattened = [x for px in all_qed for x in px]
214
        print(f"QED: {np.mean(qed_flattened):.3f} \pm {np.std(qed_flattened):.2f}")
215
216
        sa_flattened = [x for px in all_sa for x in px]
217
        print(f"SA: {np.mean(sa_flattened):.3f} \pm {np.std(sa_flattened):.2f}")
218
219
        logp_flattened = [x for px in all_logp for x in px]
220
        print(f"LogP: {np.mean(logp_flattened):.3f} \pm {np.std(logp_flattened):.2f}")
221
222
        lipinski_flattened = [x for px in all_lipinski for x in px]
223
        print(f"Lipinski: {np.mean(lipinski_flattened):.3f} \pm {np.std(lipinski_flattened):.2f}")
224
225
        print(f"Diversity: {np.mean(per_pocket_diversity):.3f} \pm {np.std(per_pocket_diversity):.2f}")
226
227
        return all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity
228
229
    def evaluate_mean(self, rdmols):
230
        """
231
        Run full evaluation and return mean of each property
232
        Args:
233
            rdmols: list of RDKit molecules
234
        Returns:
235
            QED, SA, LogP, Lipinski, and Diversity
236
        """
237
238
        if len(rdmols) < 1:
239
            return 0.0, 0.0, 0.0, 0.0, 0.0
240
241
        for mol in rdmols:
242
            Chem.SanitizeMol(mol)
243
            assert mol is not None, "only evaluate valid molecules"
244
245
        qed = np.mean([self.calculate_qed(mol) for mol in rdmols])
246
        sa = np.mean([self.calculate_sa(mol) for mol in rdmols])
247
        logp = np.mean([self.calculate_logp(mol) for mol in rdmols])
248
        lipinski = np.mean([self.calculate_lipinski(mol) for mol in rdmols])
249
        diversity = self.calculate_diversity(rdmols)
250
251
        return qed, sa, logp, lipinski, diversity