# bert_mixup/late_mixup/enumeration.py
"""
Script that performs Enumeration based augmentation for chemical SMILES
Implementation borrowed from: https://github.com/EBjerrum/SMILES-enumeration
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import os
import shutil
import numpy as np
import deepchem as dc
from deepchem.molnet import load_muv
from sklearn.ensemble import RandomForestClassifier
import pandas as pd

# Experimental Class for Smiles Enumeration, Iterator and SmilesIterator adapted from Keras 2.6.0
from rdkit import Chem
import threading

# Seed NumPy's global RNG at import time; the shuffling and SMILES
# randomization below draw from this global generator.
np.random.seed(123)
22
23
24
class Iterator(object):
    """Abstract base class for data iterators.

    Subclasses must implement `next()`. It is expected to pull the next
    `(index_array, current_index, current_batch_size)` tuple from
    `self.index_generator` while holding `self.lock` and to build a batch
    from it.

    # Arguments
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.

    # Raises
        ValueError: if `batch_size` is not positive or if `n < batch_size`.
    """

    def __init__(self, n, batch_size, shuffle, seed):
        # Validate first so that no half-initialised state is built.
        # A non-positive batch size would otherwise yield empty batches forever.
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        if n < batch_size:
            raise ValueError(
                "Input data length is shorter than batch_size\nAdjust batch_size"
            )
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batch_index = 0
        self.total_batches_seen = 0
        self.lock = threading.Lock()
        self.index_generator = self._flow_index(n, batch_size, shuffle, seed)

    def reset(self):
        """Restart batching from the beginning of the index array."""
        self.batch_index = 0

    def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
        """Endlessly yield `(index_array_slice, current_index, current_batch_size)`.

        The last batch of an epoch may be smaller than `batch_size`. A new
        index array (shuffled when `shuffle` is true) is drawn at the start of
        every epoch.

        NOTE: when `seed` is given, NumPy's *global* RNG is re-seeded with
        `seed + total_batches_seen` before every batch, which also affects any
        other code drawing from `np.random`.
        """
        # Ensure self.batch_index is 0.
        self.reset()
        while True:
            if seed is not None:
                np.random.seed(seed + self.total_batches_seen)
            if self.batch_index == 0:
                index_array = np.arange(n)
                if shuffle:
                    index_array = np.random.permutation(n)

            current_index = (self.batch_index * batch_size) % n
            if n > current_index + batch_size:
                current_batch_size = batch_size
                self.batch_index += 1
            else:
                # Final (possibly short) batch of the epoch: wrap around.
                current_batch_size = n - current_index
                self.batch_index = 0
            self.total_batches_seen += 1
            yield (
                index_array[current_index : current_index + current_batch_size],
                current_index,
                current_batch_size,
            )

    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        return self

    def next(self, *args, **kwargs):
        """Return the next batch. Must be overridden by subclasses."""
        raise NotImplementedError("Subclasses of Iterator must implement next()")

    def __next__(self, *args, **kwargs):
        return self.next(*args, **kwargs)
81
82
83
class SmilesIterator(Iterator):
    """Iterator yielding data from a SMILES array.

    # Arguments
        x: Numpy array (or any sequence convertible with `np.asarray`) of
            SMILES input data.
        y: Numpy array of targets data, or None.
        smiles_data_generator: Instance of `SmilesEnumerator`
            to use for random SMILES generation.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        dtype: dtype to use for returned batch. Set to keras.backend.floatx if using Keras

    # Raises
        ValueError: if `x` and `y` differ in length.
    """

    def __init__(
        self,
        x,
        y,
        smiles_data_generator,
        batch_size=32,
        shuffle=False,
        seed=None,
        dtype=np.float32,
    ):
        if y is not None and len(x) != len(y):
            raise ValueError(
                "X (images tensor) and y (labels) "
                "should have the same length. "
                "Found: X.shape = %s, y.shape = %s"
                % (np.asarray(x).shape, np.asarray(y).shape)
            )

        self.x = np.asarray(x)
        self.y = None if y is None else np.asarray(y)
        self.smiles_data_generator = smiles_data_generator
        self.dtype = dtype
        # Use the converted array: the raw `x` may be a plain list, which has
        # no `.shape` attribute.
        super(SmilesIterator, self).__init__(
            self.x.shape[0], batch_size, shuffle, seed
        )

    def next(self):
        """For python 2.x.

        # Returns
            The next batch: `batch_x` when no targets were given, otherwise
            the tuple `(batch_x, batch_y)`.
        """
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch.
        with self.lock:
            index_array, current_index, current_batch_size = next(self.index_generator)

        # The vectorization of the SMILES is not under thread lock
        # so it can be done in parallel
        generator = self.smiles_data_generator
        batch_x = np.zeros(
            (current_batch_size, generator.pad, generator._charlen),
            dtype=self.dtype,
        )
        for i, j in enumerate(index_array):
            # Slice (rather than index) so transform() receives a 1-element array.
            batch_x[i] = generator.transform(self.x[j : j + 1])

        if self.y is None:
            return batch_x
        return batch_x, self.y[index_array]
151
152
153
class SmilesEnumerator(object):
    """SMILES Enumerator, vectorizer and devectorizer

    #Arguments
        charset: string containing the characters for the vectorization
            can also be generated via the .fit() method
        pad: Length of the vectorization
        leftpad: Add spaces to the left of the SMILES
        isomericSmiles: Generate SMILES containing information about stereogenic centers
        enum: Enumerate the SMILES during transform
        canonical: passed as `canonical=` to RDKit when writing the randomized SMILES
    """

    def __init__(
        self,
        charset="@C)(=cOn1S2/H[N]\\",
        pad=120,
        leftpad=True,
        isomericSmiles=True,
        enum=True,
        canonical=False,
    ):
        self._charset = None
        self.charset = charset
        self.pad = pad
        self.leftpad = leftpad
        self.isomericSmiles = isomericSmiles
        self.enumerate = enum
        self.canonical = canonical

    @property
    def charset(self):
        return self._charset

    @charset.setter
    def charset(self, charset):
        # Keep the lookup tables in sync with the charset.
        self._charset = charset
        self._charlen = len(charset)
        self._char_to_int = dict((c, i) for i, c in enumerate(charset))
        self._int_to_char = dict((i, c) for i, c in enumerate(charset))

    def fit(self, smiles, extra_chars=None, extra_pad=5):
        """Extract the charset and length of a SMILES dataset and set self.pad and self.charset

        The charset is stored in sorted order so that the resulting encoding
        is identical across runs (plain set iteration order is not).

        #Arguments
            smiles: Numpy array or Pandas series containing smiles as strings
            extra_chars: List of extra chars to add to the charset (e.g. "\\\\" when "/" is present)
            extra_pad: Extra padding to add before or after the SMILES vectorization
        """
        smiles = list(smiles)
        chars = set("".join(smiles)).union(extra_chars or ())
        self.charset = "".join(sorted(chars))
        # `or [0]` keeps an empty dataset from raising inside max().
        self.pad = max([len(smile) for smile in smiles] or [0]) + extra_pad

    def randomize_smiles(self, smiles):
        """Perform a randomization of a SMILES string
        must be RDKit sanitizable

        The atom order is shuffled with NumPy's global RNG and the molecule is
        written back using `self.canonical` and `self.isomericSmiles`.
        """
        mol = Chem.MolFromSmiles(smiles)
        atom_order = list(range(mol.GetNumAtoms()))
        np.random.shuffle(atom_order)
        renumbered = Chem.RenumberAtoms(mol, atom_order)
        return Chem.MolToSmiles(
            renumbered, canonical=self.canonical, isomericSmiles=self.isomericSmiles
        )

    def transform(self, smiles):
        """Perform an enumeration (randomization) and vectorization of a Numpy array of smiles strings

        Encoding of a string stops at the first character that is not in the
        charset or that does not fit into `self.pad` positions. With
        `leftpad`, a string longer than `self.pad` is left as an all-zero row
        instead of wrapping around through negative indices.

        #Arguments
            smiles: Numpy array or Pandas series containing smiles as strings

        #Returns
            int8 array of shape (len(smiles), self.pad, len(self.charset))
        """
        one_hot = np.zeros((smiles.shape[0], self.pad, self._charlen), dtype=np.int8)
        errors = 0
        for i, ss in enumerate(smiles):
            if self.enumerate:
                ss = self.randomize_smiles(ss)
            offset = self.pad - len(ss) if self.leftpad else 0
            if offset < 0:
                # Too long for left padding: negative indices would silently
                # write characters at the wrong positions.
                errors += 1
                continue
            for j, c in enumerate(ss):
                try:
                    one_hot[i, j + offset, self._char_to_int[c]] = 1
                except (KeyError, IndexError):
                    # Unknown character or string longer than `pad`.
                    errors += 1
                    break
        if errors:
            import logging

            logging.getLogger(__name__).debug(
                "transform: %d SMILES could not be fully encoded", errors
            )
        return one_hot

    def reverse_transform(self, vect):
        """Performs a conversion of a vectorized SMILES to a smiles strings
        charset must be the same as used for vectorization.

        #Arguments
            vect: Numpy array of vectorized SMILES.
        """
        smiles = []
        for v in vect:
            # Keep only positions holding exactly one character (drops padding).
            v = v[v.sum(axis=1) == 1]
            # Find one hot encoded index with argmax, translate to char and join to string
            smiles.append("".join(self._int_to_char[i] for i in v.argmax(axis=1)))
        return np.array(smiles)

    def _enumerate_array(self, smiles, random_pairs, rand_proba):
        """Fit on `smiles`, enumerate them and optionally mix in random originals.

        Returns `(transformed, is_enumerated)`; see `enumerate_smiles_df`.
        """
        self.fit(smiles, extra_chars=["\\"])
        transformed = self.reverse_transform(self.transform(smiles))

        is_enumerated = [1] * len(smiles)
        if random_pairs:
            assert len(smiles) == len(
                transformed
            ), "The length of augmented SMILES must equal original SMILES"
            # Work on a list: assigning into the fixed-width unicode array
            # would silently truncate a longer replacement string.
            mixed = list(transformed)
            for idx in range(len(smiles)):
                if np.random.uniform() < rand_proba:
                    mixed[idx] = np.random.choice(smiles)
                    is_enumerated[idx] = 0
            transformed = np.array(mixed)
        return transformed, is_enumerated

    def enumerate_smiles(
        self,
        data_reader,
        smiles_col,
        replication_count=2,
        random_pairs=False,
        rand_proba=0.0,
    ):
        """
        Performs enumeration augmentation on the canonical molecular SMILES

        Note: refits this enumerator, overwriting `self.charset` and `self.pad`.

        Args:
            data_reader: object whose `dataset` attribute is a dataframe containing molecular SMILES
            smiles_col: column corresponding to molecular SMILES
            replication_count (int, optional): Number of enumerations for each CHEMICAL SMILE. Defaults to 2.
            random_pairs (bool, optional): replace some enumerations by a randomly chosen input SMILES.
            rand_proba (float, optional): probability of such a replacement. Defaults to 0.0.

        Returns:
            (transformed, originals, is_enumerated): the augmented SMILES, the
            repeated input SMILES as a list, and a 0/1 flag per entry
            (0 where a random input SMILES was substituted).
        """
        smiles = np.repeat(data_reader.dataset[smiles_col].values, replication_count)
        transformed, is_enumerated = self._enumerate_array(
            smiles, random_pairs, rand_proba
        )
        return transformed, list(smiles), is_enumerated

    def enumerate_smiles_df(
        self,
        data_reader,
        smiles_col,
        replication_count=2,
        random_pairs=False,
        rand_proba=0.0,
    ):
        """
        Performs enumeration augmentation on the canonical molecular SMILES

        Note: refits this enumerator, overwriting `self.charset` and `self.pad`.

        Args:
            data_reader: dataframe containing molecular SMILES
            smiles_col: column corresponding to molecular SMILES
            replication_count (int, optional): Number of enumerations for each CHEMICAL SMILE. Defaults to 2.
            random_pairs (bool, optional): replace some enumerations by a randomly chosen input SMILES.
            rand_proba (float, optional): probability of such a replacement. Defaults to 0.0.

        Returns:
            (transformed, is_enumerated): the augmented SMILES and a 0/1 flag
            per entry (0 where a random input SMILES was substituted).
        """
        smiles = np.repeat(data_reader[smiles_col].values, replication_count)
        return self._enumerate_array(smiles, random_pairs, rand_proba)

    def smiles_enumeration(self, input_smiles, replication_count=100, n_augment=0):
        """
        Performs enumeration augmentation on a single molecular SMILES

        Best effort: any error during enumeration is swallowed and whatever
        was collected so far is returned. Refits this enumerator, overwriting
        `self.charset` and `self.pad`.

        Args:
            input_smiles (str): the SMILES string to enumerate
            replication_count (int, optional): Number of enumeration attempts. Defaults to 100.
            n_augment (int, optional): stop once this many enumerations are collected.
                The check runs after each candidate, so the default 0 stops
                after the first candidate.

        Returns:
            list of enumerated SMILES that are at least as long as `input_smiles`
            (shorter results are dropped; duplicates are kept).
        """
        enumerations = []
        try:
            smiles = np.repeat([input_smiles], replication_count)
            self.fit(smiles, extra_chars=["\\"])
            transformed = self.reverse_transform(self.transform(smiles))

            for enumerated_smiles in transformed:
                if len(enumerated_smiles) >= len(input_smiles):
                    enumerations.append(enumerated_smiles)
                if len(enumerations) >= n_augment:
                    break
        except Exception:
            import logging

            logging.getLogger(__name__).debug(
                "smiles_enumeration failed for %r", input_smiles, exc_info=True
            )
        return enumerations