Diff of /torchdrug/data/feature.py [000000] .. [36b44b]

Switch to side-by-side view

--- a
+++ b/torchdrug/data/feature.py
@@ -0,0 +1,347 @@
+import warnings
+
+from rdkit import Chem
+from rdkit.Chem import AllChem
+
+from torchdrug.core import Registry as R
+
+
+# orderd by perodic table
+atom_vocab = ["H", "B", "C", "N", "O", "F", "Mg", "Si", "P", "S", "Cl", "Cu", "Zn", "Se", "Br", "Sn", "I"]
+atom_vocab = {a: i for i, a in enumerate(atom_vocab)}
+degree_vocab = range(7)
+num_hs_vocab = range(7)
+formal_charge_vocab = range(-5, 6)
+chiral_tag_vocab = range(4)
+total_valence_vocab = range(8)
+num_radical_vocab = range(8)
+hybridization_vocab = range(len(Chem.rdchem.HybridizationType.values))
+
+bond_type_vocab = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
+                   Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
+bond_type_vocab = {b: i for i, b in enumerate(bond_type_vocab)}
+bond_dir_vocab = range(len(Chem.rdchem.BondDir.values))
+bond_stereo_vocab = range(len(Chem.rdchem.BondStereo.values))
+
+# orderd by molecular mass
+residue_vocab = ["GLY", "ALA", "SER", "PRO", "VAL", "THR", "CYS", "ILE", "LEU", "ASN",
+                 "ASP", "GLN", "LYS", "GLU", "MET", "HIS", "PHE", "ARG", "TYR", "TRP"]
+
+
+def onehot(x, vocab, allow_unknown=False):
+    if x in vocab:
+        if isinstance(vocab, dict):
+            index = vocab[x]
+        else:
+            index = vocab.index(x)
+    else:
+        index = -1
+    if allow_unknown:
+        feature = [0] * (len(vocab) + 1)
+        if index == -1:
+            warnings.warn("Unknown value `%s`" % x)
+        feature[index] = 1
+    else:
+        feature = [0] * len(vocab)
+        if index == -1:
+            raise ValueError("Unknown value `%s`. Available vocabulary is `%s`" % (x, vocab))
+        feature[index] = 1
+
+    return feature
+
+
+# TODO: this one is too slow
+@R.register("features.atom.default")
+def atom_default(atom):
+    """Default atom feature.
+
+    Features:
+        GetSymbol(): one-hot embedding for the atomic symbol
+        
+        GetChiralTag(): one-hot embedding for atomic chiral tag 
+        
+        GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs
+        
+        GetFormalCharge(): one-hot embedding for the number of formal charges in the molecule
+        
+        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom 
+        
+        GetNumRadicalElectrons(): one-hot embedding for the number of radical electrons on the atom
+        
+        GetHybridization(): one-hot embedding for the atom's hybridization
+        
+        GetIsAromatic(): whether the atom is aromatic
+        
+        IsInRing(): whether the atom is in a ring
+    """
+    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
+           onehot(atom.GetChiralTag(), chiral_tag_vocab) + \
+           onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True) + \
+           onehot(atom.GetFormalCharge(), formal_charge_vocab) + \
+           onehot(atom.GetTotalNumHs(), num_hs_vocab) + \
+           onehot(atom.GetNumRadicalElectrons(), num_radical_vocab) + \
+           onehot(atom.GetHybridization(), hybridization_vocab) + \
+           [atom.GetIsAromatic(), atom.IsInRing()]
+
+
+@R.register("features.atom.center_identification")
+def atom_center_identification(atom):
+    """Reaction center identification atom feature.
+
+    Features:
+        GetSymbol(): one-hot embedding for the atomic symbol
+        
+        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom 
+        
+        GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs
+        
+        GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom
+        
+        GetIsAromatic(): whether the atom is aromatic
+        
+        IsInRing(): whether the atom is in a ring
+    """
+    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
+           onehot(atom.GetTotalNumHs(), num_hs_vocab) + \
+           onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True) + \
+           onehot(atom.GetTotalValence(), total_valence_vocab) + \
+           [atom.GetIsAromatic(), atom.IsInRing()]
+
+
+@R.register("features.atom.synthon_completion")
+def atom_synthon_completion(atom):
+    """Synthon completion atom feature.
+
+    Features:
+        GetSymbol(): one-hot embedding for the atomic symbol
+
+        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom 
+        
+        GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs
+        
+        IsInRing(): whether the atom is in a ring
+        
+        IsInRingSize(3, 4, 5, 6): whether the atom is in a ring of a particular size
+        
+        IsInRing() and not IsInRingSize(3, 4, 5, 6): whether the atom is in a ring and not in a ring of 3, 4, 5, 6
+    """
+    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
+           onehot(atom.GetTotalNumHs(), num_hs_vocab) + \
+           onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True) + \
+           [atom.IsInRing(), atom.IsInRingSize(3), atom.IsInRingSize(4),
+            atom.IsInRingSize(5), atom.IsInRingSize(6), 
+            atom.IsInRing() and (not atom.IsInRingSize(3)) and (not atom.IsInRingSize(4)) \
+            and (not atom.IsInRingSize(5)) and (not atom.IsInRingSize(6))]
+
+
+@R.register("features.atom.symbol")
+def atom_symbol(atom):
+    """Symbol atom feature.
+
+    Features:
+        GetSymbol(): one-hot embedding for the atomic symbol
+    """
+    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
+
+
+@R.register("features.atom.explicit_property_prediction")
+def atom_explicit_property_prediction(atom):
+    """Explicit property prediction atom feature.
+
+    Features:
+        GetSymbol(): one-hot embedding for the atomic symbol
+
+        GetDegree(): one-hot embedding for the degree of the atom in the molecule
+
+        GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom
+        
+        GetFormalCharge(): one-hot embedding for the number of formal charges in the molecule
+        
+        GetIsAromatic(): whether the atom is aromatic
+    """
+    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
+           onehot(atom.GetDegree(), degree_vocab, allow_unknown=True) + \
+           onehot(atom.GetTotalValence(), total_valence_vocab, allow_unknown=True) + \
+           onehot(atom.GetFormalCharge(), formal_charge_vocab) + \
+           [atom.GetIsAromatic()]
+
+
+@R.register("features.atom.property_prediction")
+def atom_property_prediction(atom):
+    """Property prediction atom feature.
+
+    Features:
+        GetSymbol(): one-hot embedding for the atomic symbol
+        
+        GetDegree(): one-hot embedding for the degree of the atom in the molecule
+        
+        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom 
+        
+        GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom
+        
+        GetFormalCharge(): one-hot embedding for the number of formal charges in the molecule
+        
+        GetIsAromatic(): whether the atom is aromatic
+    """
+    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
+           onehot(atom.GetDegree(), degree_vocab, allow_unknown=True) + \
+           onehot(atom.GetTotalNumHs(), num_hs_vocab, allow_unknown=True) + \
+           onehot(atom.GetTotalValence(), total_valence_vocab, allow_unknown=True) + \
+           onehot(atom.GetFormalCharge(), formal_charge_vocab, allow_unknown=True) + \
+           [atom.GetIsAromatic()]
+
+
+@R.register("features.atom.position")
+def atom_position(atom):
+    """
+    Atom position in the molecular conformation.
+    Return 3D position if available, otherwise 2D position is returned.
+
+    Note it takes much time to compute the conformation for large molecules.
+    """
+    mol = atom.GetOwningMol()
+    if mol.GetNumConformers() == 0:
+        mol.Compute2DCoords()
+    conformer = mol.GetConformer()
+    pos = conformer.GetAtomPosition(atom.GetIdx())
+    return [pos.x, pos.y, pos.z]
+
+
+@R.register("features.atom.pretrain")
+def atom_pretrain(atom):
+    """Atom feature for pretraining.
+
+    Features:
+        GetSymbol(): one-hot embedding for the atomic symbol
+        
+        GetChiralTag(): one-hot embedding for atomic chiral tag
+    """
+    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
+           onehot(atom.GetChiralTag(), chiral_tag_vocab)
+
+
+@R.register("features.atom.residue_symbol")
+def atom_residue_symbol(atom):
+    """Residue symbol as atom feature. Only support atoms in a protein.
+
+    Features:
+        GetSymbol(): one-hot embedding for the atomic symbol
+        GetResidueName(): one-hot embedding for the residue symbol
+    """
+    residue = atom.GetPDBResidueInfo()
+    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
+           onehot(residue.GetResidueName() if residue else -1, residue_vocab, allow_unknown=True)
+
+
+@R.register("features.bond.default")
+def bond_default(bond):
+    """Default bond feature.
+
+    Features:
+        GetBondType(): one-hot embedding for the type of the bond
+        
+        GetBondDir(): one-hot embedding for the direction of the bond
+        
+        GetStereo(): one-hot embedding for the stereo configuration of the bond
+        
+        GetIsConjugated(): whether the bond is considered to be conjugated
+    """
+    return onehot(bond.GetBondType(), bond_type_vocab) + \
+           onehot(bond.GetBondDir(), bond_dir_vocab) + \
+           onehot(bond.GetStereo(), bond_stereo_vocab) + \
+           [int(bond.GetIsConjugated())]
+
+
+@R.register("features.bond.length")
+def bond_length(bond):
+    """
+    Bond length in the molecular conformation.
+
+    Note it takes much time to compute the conformation for large molecules.
+    """
+    mol = bond.GetOwningMol()
+    if mol.GetNumConformers() == 0:
+        mol.Compute2DCoords()
+    conformer = mol.GetConformer()
+    h = conformer.GetAtomPosition(bond.GetBeginAtomIdx())
+    t = conformer.GetAtomPosition(bond.GetEndAtomIdx())
+    return [h.Distance(t)]
+
+
+@R.register("features.bond.property_prediction")
+def bond_property_prediction(bond):
+    """Property prediction bond feature.
+
+    Features:
+        GetBondType(): one-hot embedding for the type of the bond
+        
+        GetIsConjugated(): whether the bond is considered to be conjugated
+        
+        IsInRing(): whether the bond is in a ring
+    """
+    return onehot(bond.GetBondType(), bond_type_vocab) + \
+           [int(bond.GetIsConjugated()), bond.IsInRing()]
+
+
+@R.register("features.bond.pretrain")
+def bond_pretrain(bond):
+    """Bond feature for pretraining.
+
+    Features:
+        GetBondType(): one-hot embedding for the type of the bond
+        
+        GetBondDir(): one-hot embedding for the direction of the bond
+    """
+    return onehot(bond.GetBondType(), bond_type_vocab) + \
+           onehot(bond.GetBondDir(), bond_dir_vocab)
+
+
+@R.register("features.residue.symbol")
+def residue_symbol(residue):
+    """Symbol residue feature.
+
+    Features:
+        GetResidueName(): one-hot embedding for the residue symbol
+    """
+    return onehot(residue.GetResidueName(), residue_vocab, allow_unknown=True)
+
+
+@R.register("features.residue.default")
+def residue_default(residue):
+    """Default residue feature.
+
+    Features:
+        GetResidueName(): one-hot embedding for the residue symbol
+    """
+    return residue_symbol(residue)
+
+
+@R.register("features.molecule.ecfp")
+def ExtendedConnectivityFingerprint(mol, radius=2, length=1024):
+    """Extended Connectivity Fingerprint molecule feature.
+
+    Features:
+        GetMorganFingerprintAsBitVect(): a Morgan fingerprint for a molecule as a bit vector
+    """
+    ecfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, length)
+    return list(ecfp)
+
+
+@R.register("features.molecule.default")
+def molecule_default(mol):
+    """Default molecule feature."""
+    return ExtendedConnectivityFingerprint(mol)
+
+
+ECFP = ExtendedConnectivityFingerprint
+
+
+__all__ = [
+    "atom_default", "atom_center_identification", "atom_synthon_completion",
+    "atom_symbol", "atom_explicit_property_prediction", "atom_property_prediction",
+    "atom_position", "atom_pretrain", "atom_residue_symbol",
+    "bond_default", "bond_length", "bond_property_prediction", "bond_pretrain",
+    "residue_symbol", "residue_default",
+    "ExtendedConnectivityFingerprint", "molecule_default",
+    "ECFP",
+]
\ No newline at end of file