a b/torchdrug/data/feature.py
1
import warnings
2
3
from rdkit import Chem
4
from rdkit.Chem import AllChem
5
6
from torchdrug.core import Registry as R
7
8
9
# Atom symbols, ordered by their position in the periodic table (atomic number),
# stored as a symbol -> index mapping.
atom_vocab = {symbol: index for index, symbol in enumerate(
    ["H", "B", "C", "N", "O", "F", "Mg", "Si", "P", "S", "Cl", "Cu", "Zn", "Se", "Br", "Sn", "I"])}

# Integer-valued atom properties; each one is one-hot encoded over the given range.
degree_vocab = num_hs_vocab = range(7)
formal_charge_vocab = range(-5, 6)
chiral_tag_vocab = range(4)
total_valence_vocab = num_radical_vocab = range(8)
# One slot per RDKit hybridization type (assumes the enum values are 0..n-1 -- TODO confirm).
hybridization_vocab = range(len(Chem.rdchem.HybridizationType.values))

# Supported bond types, stored as a bond type -> index mapping.
bond_type_vocab = {bond_type: index for index, bond_type in enumerate((
    Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC))}
# One slot per RDKit bond direction / bond stereo enum value.
bond_dir_vocab = range(len(Chem.rdchem.BondDir.values))
bond_stereo_vocab = range(len(Chem.rdchem.BondStereo.values))

# Standard amino-acid residue names, ordered by molecular mass.
residue_vocab = ["GLY", "ALA", "SER", "PRO", "VAL", "THR", "CYS", "ILE", "LEU", "ASN",
                 "ASP", "GLN", "LYS", "GLU", "MET", "HIS", "PHE", "ARG", "TYR", "TRP"]
29
30
31
def onehot(x, vocab, allow_unknown=False):
    """One-hot encode ``x`` against ``vocab``.

    Parameters:
        x (object): value to encode.
        vocab (dict, list or range): if a dict, maps each value to its index
            (values are assumed to lie in ``range(len(vocab))``); otherwise the
            index of ``x`` is its position in the sequence.
        allow_unknown (bool, optional): if True, the vector gets one extra
            trailing slot that is set (and a warning is emitted) when ``x`` is
            not in ``vocab``. If False, an unknown ``x`` raises.

    Returns:
        list of int: one-hot vector of length ``len(vocab)``, or
        ``len(vocab) + 1`` when ``allow_unknown`` is True.

    Raises:
        ValueError: if ``x`` is not in ``vocab`` and ``allow_unknown`` is False.
    """
    if x in vocab:
        index = vocab[x] if isinstance(vocab, dict) else vocab.index(x)
    else:
        index = -1

    if allow_unknown:
        feature = [0] * (len(vocab) + 1)
        if index == -1:
            # Wrap in a 1-tuple so a tuple-valued ``x`` is formatted rather than
            # unpacked by ``%`` (which would raise TypeError).
            warnings.warn("Unknown value `%s`" % (x,))
    else:
        feature = [0] * len(vocab)
        if index == -1:
            raise ValueError("Unknown value `%s`. Available vocabulary is `%s`" % (x, vocab))

    # With ``allow_unknown``, index -1 selects the trailing "unknown" slot.
    feature[index] = 1
    return feature
51
52
53
# TODO: this featurizer is too slow; consider optimizing it.
@R.register("features.atom.default")
def atom_default(atom):
    """Default atom feature.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol

        GetChiralTag(): one-hot embedding for the atomic chiral tag

        GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs

        GetFormalCharge(): one-hot embedding for the formal charge of the atom

        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom

        GetNumRadicalElectrons(): one-hot embedding for the number of radical electrons on the atom

        GetHybridization(): one-hot embedding for the atom's hybridization

        GetIsAromatic(): whether the atom is aromatic

        IsInRing(): whether the atom is in a ring

    Symbol and degree values outside their vocabularies fall into an extra
    "unknown" slot; any other out-of-vocabulary value raises ValueError.
    """
    segments = [
        onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True),
        onehot(atom.GetChiralTag(), chiral_tag_vocab),
        onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True),
        onehot(atom.GetFormalCharge(), formal_charge_vocab),
        onehot(atom.GetTotalNumHs(), num_hs_vocab),
        onehot(atom.GetNumRadicalElectrons(), num_radical_vocab),
        onehot(atom.GetHybridization(), hybridization_vocab),
        [atom.GetIsAromatic(), atom.IsInRing()],
    ]
    # Concatenate the segments in the order listed in the docstring.
    return [value for segment in segments for value in segment]
85
86
87
@R.register("features.atom.center_identification")
def atom_center_identification(atom):
    """Reaction center identification atom feature.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol

        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom

        GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs

        GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom

        GetIsAromatic(): whether the atom is aromatic

        IsInRing(): whether the atom is in a ring
    """
    feature = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    feature += onehot(atom.GetTotalNumHs(), num_hs_vocab)
    feature += onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True)
    feature += onehot(atom.GetTotalValence(), total_valence_vocab)
    feature += [atom.GetIsAromatic(), atom.IsInRing()]
    return feature
109
110
111
@R.register("features.atom.synthon_completion")
def atom_synthon_completion(atom):
    """Synthon completion atom feature.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol

        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom

        GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs

        IsInRing(): whether the atom is in a ring

        IsInRingSize(3, 4, 5, 6): whether the atom is in a ring of a particular size

        IsInRing() and not IsInRingSize(3, 4, 5, 6): whether the atom is in a ring but not in a ring of size 3, 4, 5 or 6
    """
    feature = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    feature += onehot(atom.GetTotalNumHs(), num_hs_vocab)
    feature += onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True)

    in_ring = atom.IsInRing()
    in_small_ring = [atom.IsInRingSize(size) for size in (3, 4, 5, 6)]
    # The final flag marks ring atoms that belong to none of the 3- to 6-membered rings.
    in_other_ring = in_ring and not any(in_small_ring)
    feature += [in_ring] + in_small_ring + [in_other_ring]
    return feature
135
136
137
@R.register("features.atom.symbol")
def atom_symbol(atom):
    """Symbol atom feature.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol, with an extra slot for unknown symbols
    """
    symbol = atom.GetSymbol()
    return onehot(symbol, atom_vocab, allow_unknown=True)
145
146
147
@R.register("features.atom.explicit_property_prediction")
def atom_explicit_property_prediction(atom):
    """Explicit property prediction atom feature.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol

        GetDegree(): one-hot embedding for the degree of the atom in the molecule

        GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom

        GetFormalCharge(): one-hot embedding for the formal charge of the atom

        GetIsAromatic(): whether the atom is aromatic

    Only the formal charge segment has no "unknown" slot; an out-of-range
    charge raises ValueError.
    """
    segments = (
        onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True),
        onehot(atom.GetDegree(), degree_vocab, allow_unknown=True),
        onehot(atom.GetTotalValence(), total_valence_vocab, allow_unknown=True),
        onehot(atom.GetFormalCharge(), formal_charge_vocab),
        [atom.GetIsAromatic()],
    )
    return [value for segment in segments for value in segment]
167
168
169
@R.register("features.atom.property_prediction")
def atom_property_prediction(atom):
    """Property prediction atom feature.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol

        GetDegree(): one-hot embedding for the degree of the atom in the molecule

        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom

        GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom

        GetFormalCharge(): one-hot embedding for the formal charge of the atom

        GetIsAromatic(): whether the atom is aromatic

    Every one-hot segment carries an extra "unknown" slot, so no value raises.
    """
    feature = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    feature += onehot(atom.GetDegree(), degree_vocab, allow_unknown=True)
    feature += onehot(atom.GetTotalNumHs(), num_hs_vocab, allow_unknown=True)
    feature += onehot(atom.GetTotalValence(), total_valence_vocab, allow_unknown=True)
    feature += onehot(atom.GetFormalCharge(), formal_charge_vocab, allow_unknown=True)
    feature.append(atom.GetIsAromatic())
    return feature
192
193
194
@R.register("features.atom.position")
def atom_position(atom):
    """
    Atom position in the molecular conformation.

    Uses the conformer returned by ``mol.GetConformer()`` (3D if the molecule
    carries one). If the molecule has no conformer, 2D coordinates are computed
    and stored on the owning molecule first.

    Note it takes much time to compute the conformation for large molecules.
    """
    mol = atom.GetOwningMol()
    # Lazily add a 2D conformer; this mutates the owning molecule.
    if not mol.GetNumConformers():
        mol.Compute2DCoords()
    position = mol.GetConformer().GetAtomPosition(atom.GetIdx())
    return [position.x, position.y, position.z]
208
209
210
@R.register("features.atom.pretrain")
def atom_pretrain(atom):
    """Atom feature for pretraining.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol

        GetChiralTag(): one-hot embedding for the atomic chiral tag
    """
    symbol_feature = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    chirality_feature = onehot(atom.GetChiralTag(), chiral_tag_vocab)
    return symbol_feature + chirality_feature
221
222
223
@R.register("features.atom.residue_symbol")
def atom_residue_symbol(atom):
    """Residue symbol as atom feature. Intended for atoms in a protein.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol

        GetResidueName(): one-hot embedding for the residue symbol; atoms without
        PDB residue info fall into the unknown residue slot (with a warning)
    """
    residue = atom.GetPDBResidueInfo()
    symbol_feature = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    if residue:
        residue_name = residue.GetResidueName()
    else:
        # Sentinel that is never in residue_vocab, so it maps to the unknown slot.
        residue_name = -1
    return symbol_feature + onehot(residue_name, residue_vocab, allow_unknown=True)
234
235
236
@R.register("features.bond.default")
def bond_default(bond):
    """Default bond feature.

    Features:
        GetBondType(): one-hot embedding for the type of the bond

        GetBondDir(): one-hot embedding for the direction of the bond

        GetStereo(): one-hot embedding for the stereo configuration of the bond

        GetIsConjugated(): whether the bond is considered to be conjugated (as 0/1)
    """
    segments = (
        onehot(bond.GetBondType(), bond_type_vocab),
        onehot(bond.GetBondDir(), bond_dir_vocab),
        onehot(bond.GetStereo(), bond_stereo_vocab),
        [int(bond.GetIsConjugated())],
    )
    return [value for segment in segments for value in segment]
253
254
255
@R.register("features.bond.length")
def bond_length(bond):
    """
    Bond length in the molecular conformation.

    Measured as the distance between the two bonded atoms in the conformer
    returned by ``mol.GetConformer()``. If the molecule has no conformer, 2D
    coordinates are computed and stored on the owning molecule first.

    Note it takes much time to compute the conformation for large molecules.
    """
    mol = bond.GetOwningMol()
    # Lazily add a 2D conformer; this mutates the owning molecule.
    if not mol.GetNumConformers():
        mol.Compute2DCoords()
    conformer = mol.GetConformer()
    begin = conformer.GetAtomPosition(bond.GetBeginAtomIdx())
    end = conformer.GetAtomPosition(bond.GetEndAtomIdx())
    return [begin.Distance(end)]
269
270
271
@R.register("features.bond.property_prediction")
def bond_property_prediction(bond):
    """Property prediction bond feature.

    Features:
        GetBondType(): one-hot embedding for the type of the bond

        GetIsConjugated(): whether the bond is considered to be conjugated (as 0/1)

        IsInRing(): whether the bond is in a ring
    """
    type_feature = onehot(bond.GetBondType(), bond_type_vocab)
    flags = [int(bond.GetIsConjugated()), bond.IsInRing()]
    return type_feature + flags
284
285
286
@R.register("features.bond.pretrain")
def bond_pretrain(bond):
    """Bond feature for pretraining.

    Features:
        GetBondType(): one-hot embedding for the type of the bond

        GetBondDir(): one-hot embedding for the direction of the bond
    """
    type_feature = onehot(bond.GetBondType(), bond_type_vocab)
    direction_feature = onehot(bond.GetBondDir(), bond_dir_vocab)
    return type_feature + direction_feature
297
298
299
@R.register("features.residue.symbol")
def residue_symbol(residue):
    """Symbol residue feature.

    Features:
        GetResidueName(): one-hot embedding for the residue symbol, with an extra slot for unknown residues
    """
    name = residue.GetResidueName()
    return onehot(name, residue_vocab, allow_unknown=True)
307
308
309
@R.register("features.residue.default")
def residue_default(residue):
    """Default residue feature; identical to ``residue_symbol``.

    Features:
        GetResidueName(): one-hot embedding for the residue symbol
    """
    feature = residue_symbol(residue)
    return feature
317
318
319
@R.register("features.molecule.ecfp")
def ExtendedConnectivityFingerprint(mol, radius=2, length=1024):
    """Extended Connectivity Fingerprint molecule feature.

    Parameters:
        mol (rdkit.Chem.Mol): molecule to fingerprint
        radius (int, optional): Morgan radius
        length (int, optional): number of bits in the fingerprint

    Features:
        GetMorganFingerprintAsBitVect(): a Morgan fingerprint for a molecule as a bit vector
    """
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(mol, radius, length)
    return [bit for bit in bit_vector]
328
329
330
@R.register("features.molecule.default")
def molecule_default(mol):
    """Default molecule feature: the ECFP of ``mol`` with its default radius and length."""
    fingerprint = ExtendedConnectivityFingerprint(mol)
    return fingerprint
334
335
336
# Short alias for the ECFP molecule featurizer.
ECFP = ExtendedConnectivityFingerprint


# Public API of this module, grouped by the kind of object each feature describes.
__all__ = [
    # atom features
    "atom_default", "atom_center_identification", "atom_synthon_completion",
    "atom_symbol", "atom_explicit_property_prediction", "atom_property_prediction",
    "atom_position", "atom_pretrain", "atom_residue_symbol",
    # bond features
    "bond_default", "bond_length", "bond_property_prediction", "bond_pretrain",
    # residue features
    "residue_symbol", "residue_default",
    # molecule features
    "ExtendedConnectivityFingerprint", "molecule_default", "ECFP",
]