[36b44b]: / torchdrug / data / feature.py

Download this file

347 lines (250 with data), 12.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import warnings
from rdkit import Chem
from rdkit.Chem import AllChem
from torchdrug.core import Registry as R
# Vocabularies consumed by the feature functions in this module.
# A dict vocabulary maps each value straight to its one-hot index; a list or
# range vocabulary is indexed by position (see `onehot`).

# ordered by periodic table
atom_vocab = ["H", "B", "C", "N", "O", "F", "Mg", "Si", "P", "S", "Cl", "Cu", "Zn", "Se", "Br", "Sn", "I"]
atom_vocab = {a: i for i, a in enumerate(atom_vocab)}
degree_vocab = range(7)          # atom degrees 0..6
num_hs_vocab = range(7)          # hydrogen counts 0..6
formal_charge_vocab = range(-5, 6)  # formal charges -5..+5
# NOTE(review): covers RDKit ChiralType values 0..3 only; newer RDKit releases may
# define additional chiral types -- confirm this is still sufficient.
chiral_tag_vocab = range(4)
total_valence_vocab = range(8)   # total valences 0..7
num_radical_vocab = range(8)     # radical electron counts 0..7
# one slot per RDKit HybridizationType value
hybridization_vocab = range(len(Chem.rdchem.HybridizationType.values))
# only these four bond types are in the vocabulary
bond_type_vocab = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                   Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
bond_type_vocab = {b: i for i, b in enumerate(bond_type_vocab)}
# one slot per RDKit BondDir / BondStereo value
bond_dir_vocab = range(len(Chem.rdchem.BondDir.values))
bond_stereo_vocab = range(len(Chem.rdchem.BondStereo.values))
# three-letter residue codes, ordered by molecular mass
residue_vocab = ["GLY", "ALA", "SER", "PRO", "VAL", "THR", "CYS", "ILE", "LEU", "ASN",
                 "ASP", "GLN", "LYS", "GLU", "MET", "HIS", "PHE", "ARG", "TYR", "TRP"]
def onehot(x, vocab, allow_unknown=False):
    """One-hot encode ``x`` against ``vocab``.

    Parameters:
        x: value to encode.
        vocab (dict, list or range): vocabulary. A dict maps each value to its
            index; any other sequence uses the position returned by ``index``.
        allow_unknown (bool, optional): if True, reserve one extra trailing slot
            that is set (and a warning is emitted) when ``x`` is not in ``vocab``.
            Otherwise an out-of-vocabulary ``x`` raises ``ValueError``.

    Returns:
        list of int: ``len(vocab)`` entries (``len(vocab) + 1`` when
        ``allow_unknown`` is True) with exactly one entry set to 1.

    Raises:
        ValueError: if ``x`` is not in ``vocab`` and ``allow_unknown`` is False.
    """
    try:
        # Single lookup: dicts raise KeyError, sequences' ``index`` raises ValueError.
        index = vocab[x] if isinstance(vocab, dict) else vocab.index(x)
    except (KeyError, ValueError):
        index = -1

    if index == -1:
        if not allow_unknown:
            raise ValueError("Unknown value `%s`. Available vocabulary is `%s`" % (x, vocab))
        # Wrap x in a tuple so tuple-valued inputs format instead of raising TypeError.
        warnings.warn("Unknown value `%s`" % (x,))

    size = len(vocab) + 1 if allow_unknown else len(vocab)
    feature = [0] * size
    # index -1 selects the trailing "unknown" slot when allow_unknown is set
    feature[index] = 1
    return feature
# TODO: this one is too slow
@R.register("features.atom.default")
def atom_default(atom):
    """Default atom feature.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol (plus an unknown slot)
        GetChiralTag(): one-hot embedding for atomic chiral tag
        GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs
            (plus an unknown slot)
        GetFormalCharge(): one-hot embedding for the formal charge of the atom
        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom
        GetNumRadicalElectrons(): one-hot embedding for the number of radical electrons on the atom
        GetHybridization(): one-hot embedding for the atom's hybridization
        GetIsAromatic(): whether the atom is aromatic
        IsInRing(): whether the atom is in a ring

    Raises:
        ValueError: if the chiral tag, formal charge, number of Hs, number of radical
            electrons or hybridization is outside its vocabulary, because those
            encodings reserve no unknown slot.
    """
    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
        onehot(atom.GetChiralTag(), chiral_tag_vocab) + \
        onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True) + \
        onehot(atom.GetFormalCharge(), formal_charge_vocab) + \
        onehot(atom.GetTotalNumHs(), num_hs_vocab) + \
        onehot(atom.GetNumRadicalElectrons(), num_radical_vocab) + \
        onehot(atom.GetHybridization(), hybridization_vocab) + \
        [atom.GetIsAromatic(), atom.IsInRing()]
@R.register("features.atom.center_identification")
def atom_center_identification(atom):
    """Atom feature used for reaction center identification.

    Features:
        GetSymbol(): one-hot atomic symbol (plus an unknown slot)
        GetTotalNumHs(): one-hot total number of Hs (explicit and implicit) on the atom
        GetTotalDegree(): one-hot degree of the atom including Hs (plus an unknown slot)
        GetTotalValence(): one-hot total valence (explicit + implicit) of the atom
        GetIsAromatic(): whether the atom is aromatic
        IsInRing(): whether the atom is in a ring

    Raises:
        ValueError: if the number of Hs or the total valence is outside its vocabulary.
    """
    feature = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    feature += onehot(atom.GetTotalNumHs(), num_hs_vocab)
    feature += onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True)
    feature += onehot(atom.GetTotalValence(), total_valence_vocab)
    feature.append(atom.GetIsAromatic())
    feature.append(atom.IsInRing())
    return feature
@R.register("features.atom.synthon_completion")
def atom_synthon_completion(atom):
    """Atom feature used for synthon completion.

    Features:
        GetSymbol(): one-hot atomic symbol (plus an unknown slot)
        GetTotalNumHs(): one-hot total number of Hs (explicit and implicit) on the atom
        GetTotalDegree(): one-hot degree of the atom including Hs (plus an unknown slot)
        IsInRing(): whether the atom is in a ring
        IsInRingSize(3, 4, 5, 6): whether the atom is in a ring of each of these sizes
        IsInRing() and not IsInRingSize(3, 4, 5, 6): whether the atom is in a ring,
            but in none of size 3, 4, 5 or 6

    Raises:
        ValueError: if the number of Hs is outside its vocabulary.
    """
    symbol = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    num_hs = onehot(atom.GetTotalNumHs(), num_hs_vocab)
    degree = onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True)

    in_ring = atom.IsInRing()
    small_rings = [atom.IsInRingSize(size) for size in (3, 4, 5, 6)]
    # ring membership that none of the explicit size flags account for
    in_other_ring = in_ring and not any(small_rings)

    return symbol + num_hs + degree + [in_ring] + small_rings + [in_other_ring]
@R.register("features.atom.symbol")
def atom_symbol(atom):
    """Atom feature consisting only of the element symbol.

    Features:
        GetSymbol(): one-hot atomic symbol (plus an unknown slot)
    """
    symbol = atom.GetSymbol()
    return onehot(symbol, atom_vocab, allow_unknown=True)
@R.register("features.atom.explicit_property_prediction")
def atom_explicit_property_prediction(atom):
    """Explicit property prediction atom feature.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol (plus an unknown slot)
        GetDegree(): one-hot embedding for the degree of the atom in the molecule (plus an unknown slot)
        GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom
            (plus an unknown slot)
        GetFormalCharge(): one-hot embedding for the formal charge of the atom
        GetIsAromatic(): whether the atom is aromatic

    Raises:
        ValueError: if the formal charge is outside ``formal_charge_vocab``, since that
            encoding reserves no unknown slot.
    """
    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
        onehot(atom.GetDegree(), degree_vocab, allow_unknown=True) + \
        onehot(atom.GetTotalValence(), total_valence_vocab, allow_unknown=True) + \
        onehot(atom.GetFormalCharge(), formal_charge_vocab) + \
        [atom.GetIsAromatic()]
@R.register("features.atom.property_prediction")
def atom_property_prediction(atom):
    """Property prediction atom feature.

    Every one-hot embedding below reserves an extra unknown slot, so out-of-vocabulary
    values emit a warning instead of raising.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol
        GetDegree(): one-hot embedding for the degree of the atom in the molecule
        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom
        GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom
        GetFormalCharge(): one-hot embedding for the formal charge of the atom
        GetIsAromatic(): whether the atom is aromatic
    """
    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
        onehot(atom.GetDegree(), degree_vocab, allow_unknown=True) + \
        onehot(atom.GetTotalNumHs(), num_hs_vocab, allow_unknown=True) + \
        onehot(atom.GetTotalValence(), total_valence_vocab, allow_unknown=True) + \
        onehot(atom.GetFormalCharge(), formal_charge_vocab, allow_unknown=True) + \
        [atom.GetIsAromatic()]
@R.register("features.atom.position")
def atom_position(atom):
    """
    Coordinates of the atom in its molecule's conformation.

    Uses the molecule's existing conformer when one is present; otherwise 2D
    coordinates are computed on the owning molecule first.
    Note computing the conformation can take much time for large molecules.

    Returns:
        list of float: ``[x, y, z]`` position of the atom.
    """
    owner = atom.GetOwningMol()
    if not owner.GetNumConformers():
        # Called on the owning molecule itself; presumably the new conformer is stored
        # there, so later calls for other atoms skip this step.
        owner.Compute2DCoords()
    point = owner.GetConformer().GetAtomPosition(atom.GetIdx())
    return [point.x, point.y, point.z]
@R.register("features.atom.pretrain")
def atom_pretrain(atom):
    """Atom feature used for pretraining.

    Features:
        GetSymbol(): one-hot atomic symbol (plus an unknown slot)
        GetChiralTag(): one-hot atomic chiral tag

    Raises:
        ValueError: if the chiral tag is outside ``chiral_tag_vocab``.
    """
    symbol = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    chirality = onehot(atom.GetChiralTag(), chiral_tag_vocab)
    return symbol + chirality
@R.register("features.atom.residue_symbol")
def atom_residue_symbol(atom):
    """Atom feature made of the element symbol and its residue type. Only meant for atoms in a protein.

    Features:
        GetSymbol(): one-hot atomic symbol (plus an unknown slot)
        GetResidueName(): one-hot residue symbol (plus an unknown slot)
    """
    info = atom.GetPDBResidueInfo()
    # Atoms without residue info get the sentinel -1, which is never in residue_vocab
    # and therefore lands in the unknown slot (with a warning from onehot).
    residue_name = info.GetResidueName() if info else -1
    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
        onehot(residue_name, residue_vocab, allow_unknown=True)
@R.register("features.bond.default")
def bond_default(bond):
    """Default bond feature.

    Features:
        GetBondType(): one-hot bond type
        GetBondDir(): one-hot bond direction
        GetStereo(): one-hot stereo configuration of the bond
        GetIsConjugated(): 1 if the bond is conjugated, else 0

    Raises:
        ValueError: if the bond type, direction or stereo value is outside its vocabulary.
    """
    pieces = [
        onehot(bond.GetBondType(), bond_type_vocab),
        onehot(bond.GetBondDir(), bond_dir_vocab),
        onehot(bond.GetStereo(), bond_stereo_vocab),
        [int(bond.GetIsConjugated())],
    ]
    return [value for piece in pieces for value in piece]
@R.register("features.bond.length")
def bond_length(bond):
    """
    Length of the bond in its molecule's conformation.

    If the owning molecule has no conformer, 2D coordinates are computed on it first.
    Note computing the conformation can take much time for large molecules.

    Returns:
        list of float: single-element list holding the distance between the bond's atoms.
    """
    owner = bond.GetOwningMol()
    if not owner.GetNumConformers():
        owner.Compute2DCoords()
    conformer = owner.GetConformer()
    begin = conformer.GetAtomPosition(bond.GetBeginAtomIdx())
    end = conformer.GetAtomPosition(bond.GetEndAtomIdx())
    distance = begin.Distance(end)
    return [distance]
@R.register("features.bond.property_prediction")
def bond_property_prediction(bond):
    """Bond feature used for property prediction.

    Features:
        GetBondType(): one-hot bond type
        GetIsConjugated(): 1 if the bond is conjugated, else 0
        IsInRing(): whether the bond is in a ring

    Raises:
        ValueError: if the bond type is outside ``bond_type_vocab``.
    """
    bond_type = onehot(bond.GetBondType(), bond_type_vocab)
    flags = [int(bond.GetIsConjugated()), bond.IsInRing()]
    return bond_type + flags
@R.register("features.bond.pretrain")
def bond_pretrain(bond):
    """Bond feature used for pretraining.

    Features:
        GetBondType(): one-hot bond type
        GetBondDir(): one-hot bond direction

    Raises:
        ValueError: if the bond type or direction is outside its vocabulary.
    """
    bond_type = onehot(bond.GetBondType(), bond_type_vocab)
    direction = onehot(bond.GetBondDir(), bond_dir_vocab)
    return bond_type + direction
@R.register("features.residue.symbol")
def residue_symbol(residue):
    """Residue feature consisting only of the residue type.

    Features:
        GetResidueName(): one-hot residue symbol (plus an unknown slot)
    """
    name = residue.GetResidueName()
    return onehot(name, residue_vocab, allow_unknown=True)
@R.register("features.residue.default")
def residue_default(residue):
    """Default residue feature; identical to ``residue_symbol``.

    Features:
        GetResidueName(): one-hot residue symbol (plus an unknown slot)
    """
    feature = residue_symbol(residue)
    return feature
@R.register("features.molecule.ecfp")
def ExtendedConnectivityFingerprint(mol, radius=2, length=1024):
    """Extended Connectivity Fingerprint (ECFP) molecule feature.

    Parameters:
        mol: RDKit molecule to fingerprint.
        radius (int, optional): Morgan radius.
        length (int, optional): number of bits in the fingerprint.

    Features:
        GetMorganFingerprintAsBitVect(): a Morgan fingerprint for a molecule as a bit vector

    Returns:
        list: the fingerprint bits, in order.
    """
    # NOTE(review): newer RDKit releases steer users toward rdFingerprintGenerator;
    # confirm this call is still supported before upgrading RDKit.
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(mol, radius, length)
    return [bit for bit in bit_vector]
@R.register("features.molecule.default")
def molecule_default(mol):
    """Default molecule feature: the ECFP with its default radius and length."""
    fingerprint = ExtendedConnectivityFingerprint(mol)
    return fingerprint
# Short alias for ExtendedConnectivityFingerprint.
ECFP = ExtendedConnectivityFingerprint

# Public names of this module; `onehot` and the vocabularies are not listed.
__all__ = [
    # atom features
    "atom_default", "atom_center_identification", "atom_synthon_completion",
    "atom_symbol", "atom_explicit_property_prediction", "atom_property_prediction",
    "atom_position", "atom_pretrain", "atom_residue_symbol",
    # bond features
    "bond_default", "bond_length", "bond_property_prediction", "bond_pretrain",
    # residue features
    "residue_symbol", "residue_default",
    # molecule features
    "ExtendedConnectivityFingerprint", "molecule_default",
    "ECFP",
]